{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from argparse import Namespace\n",
    "from collections import Counter\n",
    "import json\n",
    "import re\n",
    "import string\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Vocabulary(object):\n",
    "    \"\"\"Class to process text and extract vocabulary for mapping\"\"\"\n",
    "\n",
    "    def __init__(self, token_to_idx=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            token_to_idx (dict): a pre-existing map of tokens to indices\n",
    "        \"\"\"\n",
    "\n",
    "        if token_to_idx is None:\n",
    "            token_to_idx = {}\n",
    "        self._token_to_idx = token_to_idx\n",
    "\n",
    "        self._idx_to_token = {idx: token \n",
    "                              for token, idx in self._token_to_idx.items()}\n",
    "        \n",
    "    def to_serializable(self):\n",
    "        \"\"\" returns a dictionary that can be serialized \"\"\"\n",
    "        return {'token_to_idx': self._token_to_idx}\n",
    "\n",
    "    @classmethod\n",
    "    def from_serializable(cls, contents):\n",
    "        \"\"\" instantiates the Vocabulary from a serialized dictionary \"\"\"\n",
    "        return cls(**contents)\n",
    "\n",
    "    def add_token(self, token):\n",
    "        \"\"\"Update mapping dicts based on the token.\n",
    "\n",
    "        Args:\n",
    "            token (str): the item to add into the Vocabulary\n",
    "        Returns:\n",
    "            index (int): the integer corresponding to the token\n",
    "        \"\"\"\n",
    "        if token in self._token_to_idx:\n",
    "            index = self._token_to_idx[token]\n",
    "        else:\n",
    "            index = len(self._token_to_idx)\n",
    "            self._token_to_idx[token] = index\n",
    "            self._idx_to_token[index] = token\n",
    "        return index\n",
    "            \n",
    "    def add_many(self, tokens):\n",
    "        \"\"\"Add a list of tokens into the Vocabulary\n",
    "        \n",
    "        Args:\n",
    "            tokens (list): a list of string tokens\n",
    "        Returns:\n",
    "            indices (list): a list of indices corresponding to the tokens\n",
    "        \"\"\"\n",
    "        return [self.add_token(token) for token in tokens]\n",
    "\n",
    "    def lookup_token(self, token):\n",
    "        \"\"\"Retrieve the index associated with the token \n",
    "        \n",
    "        Args:\n",
    "            token (str): the token to look up \n",
    "        Returns:\n",
    "            index (int): the index corresponding to the token\n",
    "        \"\"\"\n",
    "        return self._token_to_idx[token]\n",
    "\n",
    "    def lookup_index(self, index):\n",
    "        \"\"\"Return the token associated with the index\n",
    "        \n",
    "        Args: \n",
    "            index (int): the index to look up\n",
    "        Returns:\n",
    "            token (str): the token corresponding to the index\n",
    "        Raises:\n",
    "            KeyError: if the index is not in the Vocabulary\n",
    "        \"\"\"\n",
    "        if index not in self._idx_to_token:\n",
    "            raise KeyError(\"the index (%d) is not in the Vocabulary\" % index)\n",
    "        return self._idx_to_token[index]\n",
    "\n",
    "    def __str__(self):\n",
    "        return \"<Vocabulary(size=%d)>\" % len(self)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._token_to_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SequenceVocabulary(Vocabulary):\n",
    "    \"\"\"Vocabulary that always contains mask, unknown, begin- and end-of-sequence tokens.\"\"\"\n",
    "\n",
    "    def __init__(self, token_to_idx=None, unk_token=\"<UNK>\",\n",
    "                 mask_token=\"<MASK>\", begin_seq_token=\"<BEGIN>\",\n",
    "                 end_seq_token=\"<END>\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            token_to_idx (dict): a pre-existing map of tokens to indices\n",
    "            unk_token (str): token whose index is returned for unknown lookups\n",
    "            mask_token (str): token used for padding\n",
    "            begin_seq_token (str): begin-of-sequence token\n",
    "            end_seq_token (str): end-of-sequence token\n",
    "        \"\"\"\n",
    "        super(SequenceVocabulary, self).__init__(token_to_idx)\n",
    "\n",
    "        self._mask_token = mask_token\n",
    "        self._unk_token = unk_token\n",
    "        self._begin_seq_token = begin_seq_token\n",
    "        self._end_seq_token = end_seq_token\n",
    "\n",
    "        # registration order matters: in an empty vocabulary this yields\n",
    "        # mask=0, unk=1, begin=2, end=3; existing tokens keep their indices\n",
    "        self.mask_index = self.add_token(mask_token)\n",
    "        self.unk_index = self.add_token(unk_token)\n",
    "        self.begin_seq_index = self.add_token(begin_seq_token)\n",
    "        self.end_seq_index = self.add_token(end_seq_token)\n",
    "\n",
    "    def to_serializable(self):\n",
    "        \"\"\" returns a serializable dictionary that also records the special tokens \"\"\"\n",
    "        serialized = super(SequenceVocabulary, self).to_serializable()\n",
    "        serialized['unk_token'] = self._unk_token\n",
    "        serialized['mask_token'] = self._mask_token\n",
    "        serialized['begin_seq_token'] = self._begin_seq_token\n",
    "        serialized['end_seq_token'] = self._end_seq_token\n",
    "        return serialized\n",
    "\n",
    "    def lookup_token(self, token):\n",
    "        \"\"\"Retrieve the index of a token, falling back to the UNK index.\n",
    "\n",
    "        Args:\n",
    "            token (str): the token to look up\n",
    "        Returns:\n",
    "            index (int): the token's index, or unk_index if the token is absent\n",
    "        Notes:\n",
    "            the fallback only applies while unk_index >= 0; otherwise an absent\n",
    "              token raises KeyError\n",
    "        \"\"\"\n",
    "        if self.unk_index < 0:\n",
    "            return self._token_to_idx[token]\n",
    "        return self._token_to_idx.get(token, self.unk_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NMTVectorizer(object):\n",
    "    \"\"\" The Vectorizer which coordinates the Vocabularies and puts them to use\"\"\"        \n",
    "    def __init__(self, source_vocab, target_vocab, max_source_length, max_target_length):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            source_vocab (SequenceVocabulary): maps source words to integers\n",
    "            target_vocab (SequenceVocabulary): maps target words to integers\n",
    "            max_source_length (int): the longest sequence in the source dataset\n",
    "            max_target_length (int): the longest sequence in the target dataset\n",
    "        \"\"\"\n",
    "        self.source_vocab = source_vocab\n",
    "        self.target_vocab = target_vocab\n",
    "        \n",
    "        self.max_source_length = max_source_length\n",
    "        self.max_target_length = max_target_length\n",
    "        \n",
    "\n",
    "    def _vectorize(self, indices, vector_length=-1, mask_index=0):\n",
    "        \"\"\"Vectorize the provided indices\n",
    "        \n",
    "        Args:\n",
    "            indices (list): a list of integers that represent a sequence\n",
    "            vector_length (int): an argument for forcing the length of index vector\n",
    "            mask_index (int): the mask_index to use; almost always 0\n",
    "        \"\"\"\n",
    "        if vector_length < 0:\n",
    "            vector_length = len(indices)\n",
    "        \n",
    "        vector = np.zeros(vector_length, dtype=np.int64)\n",
    "        vector[:len(indices)] = indices\n",
    "        vector[len(indices):] = mask_index\n",
    "\n",
    "        return vector\n",
    "    \n",
    "    def _get_source_indices(self, text):\n",
    "        \"\"\"Return the vectorized source text\n",
    "        \n",
    "        Args:\n",
    "            text (str): the source text; tokens should be separated by spaces\n",
    "        Returns:\n",
    "            indices (list): list of integers representing the text\n",
    "        \"\"\"\n",
    "        indices = [self.source_vocab.begin_seq_index]\n",
    "        indices.extend(self.source_vocab.lookup_token(token) for token in text.split(\" \"))\n",
    "        indices.append(self.source_vocab.end_seq_index)\n",
    "        return indices\n",
    "    \n",
    "    def _get_target_indices(self, text):\n",
    "        \"\"\"Return the vectorized source text\n",
    "        \n",
    "        Args:\n",
    "            text (str): the source text; tokens should be separated by spaces\n",
    "        Returns:\n",
    "            a tuple: (x_indices, y_indices)\n",
    "                x_indices (list): list of integers representing the observations in target decoder \n",
    "                y_indices (list): list of integers representing predictions in target decoder\n",
    "        \"\"\"\n",
    "        indices = [self.target_vocab.lookup_token(token) for token in text.split(\" \")]\n",
    "        x_indices = [self.target_vocab.begin_seq_index] + indices\n",
    "        y_indices = indices + [self.target_vocab.end_seq_index]\n",
    "        return x_indices, y_indices\n",
    "        \n",
    "    def vectorize(self, source_text, target_text, use_dataset_max_lengths=True):\n",
    "        \"\"\"Return the vectorized source and target text\n",
    "        \n",
    "        The vetorized source text is just the a single vector.\n",
    "        The vectorized target text is split into two vectors in a similar style to \n",
    "            the surname modeling in Chapter 7.\n",
    "        At each timestep, the first vector is the observation and the second vector is the target. \n",
    "        \n",
    "        \n",
    "        Args:\n",
    "            source_text (str): text from the source language\n",
    "            target_text (str): text from the target language\n",
    "            use_dataset_max_lengths (bool): whether to use the global max vector lengths\n",
    "        Returns:\n",
    "            The vectorized data point as a dictionary with the keys: \n",
    "                source_vector, target_x_vector, target_y_vector, source_length\n",
    "        \"\"\"\n",
    "        source_vector_length = -1\n",
    "        target_vector_length = -1\n",
    "        \n",
    "        if use_dataset_max_lengths:\n",
    "            source_vector_length = self.max_source_length + 2\n",
    "            target_vector_length = self.max_target_length + 1\n",
    "            \n",
    "        source_indices = self._get_source_indices(source_text)\n",
    "        source_vector = self._vectorize(source_indices, \n",
    "                                        vector_length=source_vector_length, \n",
    "                                        mask_index=self.source_vocab.mask_index)\n",
    "        \n",
    "        target_x_indices, target_y_indices = self._get_target_indices(target_text)\n",
    "        target_x_vector = self._vectorize(target_x_indices,\n",
    "                                        vector_length=target_vector_length,\n",
    "                                        mask_index=self.target_vocab.mask_index)\n",
    "        target_y_vector = self._vectorize(target_y_indices,\n",
    "                                        vector_length=target_vector_length,\n",
    "                                        mask_index=self.target_vocab.mask_index)\n",
    "        return {\"source_vector\": source_vector, \n",
    "                \"target_x_vector\": target_x_vector, \n",
    "                \"target_y_vector\": target_y_vector, \n",
    "                \"source_length\": len(source_indices)}\n",
    "        \n",
    "    @classmethod\n",
    "    def from_dataframe(cls, bitext_df):\n",
    "        \"\"\"Instantiate the vectorizer from the dataset dataframe\n",
    "        \n",
    "        Args:\n",
    "            bitext_df (pandas.DataFrame): the parallel text dataset\n",
    "        Returns:\n",
    "            an instance of the NMTVectorizer\n",
    "        \"\"\"\n",
    "        source_vocab = SequenceVocabulary()\n",
    "        target_vocab = SequenceVocabulary()\n",
    "        \n",
    "        max_source_length = 0\n",
    "        max_target_length = 0\n",
    "\n",
    "        for _, row in bitext_df.iterrows():\n",
    "            source_tokens = row[\"source_language\"].split(\" \")\n",
    "            if len(source_tokens) > max_source_length:\n",
    "                max_source_length = len(source_tokens)\n",
    "            for token in source_tokens:\n",
    "                source_vocab.add_token(token)\n",
    "            \n",
    "            target_tokens = row[\"target_language\"].split(\" \")\n",
    "            if len(target_tokens) > max_target_length:\n",
    "                max_target_length = len(target_tokens)\n",
    "            for token in target_tokens:\n",
    "                target_vocab.add_token(token)\n",
    "            \n",
    "        return cls(source_vocab, target_vocab, max_source_length, max_target_length)\n",
    "\n",
    "    @classmethod\n",
    "    def from_serializable(cls, contents):\n",
    "        source_vocab = SequenceVocabulary.from_serializable(contents[\"source_vocab\"])\n",
    "        target_vocab = SequenceVocabulary.from_serializable(contents[\"target_vocab\"])\n",
    "        \n",
    "        return cls(source_vocab=source_vocab, \n",
    "                   target_vocab=target_vocab, \n",
    "                   max_source_length=contents[\"max_source_length\"], \n",
    "                   max_target_length=contents[\"max_target_length\"])\n",
    "\n",
    "    def to_serializable(self):\n",
    "        return {\"source_vocab\": self.source_vocab.to_serializable(), \n",
    "                \"target_vocab\": self.target_vocab.to_serializable(), \n",
    "                \"max_source_length\": self.max_source_length,\n",
    "                \"max_target_length\": self.max_target_length}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NMTDataset(Dataset):\n",
    "    def __init__(self, text_df, vectorizer):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            surname_df (pandas.DataFrame): the dataset\n",
    "            vectorizer (SurnameVectorizer): vectorizer instatiated from dataset\n",
    "        \"\"\"\n",
    "        self.text_df = text_df\n",
    "        self._vectorizer = vectorizer\n",
    "\n",
    "        self.train_df = self.text_df[self.text_df.split=='train']\n",
    "        self.train_size = len(self.train_df)\n",
    "\n",
    "        self.val_df = self.text_df[self.text_df.split=='val']\n",
    "        self.validation_size = len(self.val_df)\n",
    "\n",
    "        self.test_df = self.text_df[self.text_df.split=='test']\n",
    "        self.test_size = len(self.test_df)\n",
    "\n",
    "        self._lookup_dict = {'train': (self.train_df, self.train_size),\n",
    "                             'val': (self.val_df, self.validation_size),\n",
    "                             'test': (self.test_df, self.test_size)}\n",
    "\n",
    "        self.set_split('train')\n",
    "\n",
    "    @classmethod\n",
    "    def load_dataset_and_make_vectorizer(cls, dataset_csv):\n",
    "        \"\"\"Load dataset and make a new vectorizer from scratch\n",
    "        \n",
    "        Args:\n",
    "            surname_csv (str): location of the dataset\n",
    "        Returns:\n",
    "            an instance of SurnameDataset\n",
    "        \"\"\"\n",
    "        text_df = pd.read_csv(dataset_csv)\n",
    "        train_subset = text_df[text_df.split=='train']\n",
    "        return cls(text_df, NMTVectorizer.from_dataframe(train_subset))\n",
    "\n",
    "    @classmethod\n",
    "    def load_dataset_and_load_vectorizer(cls, dataset_csv, vectorizer_filepath):\n",
    "        \"\"\"Load dataset and the corresponding vectorizer. \n",
    "        Used in the case in the vectorizer has been cached for re-use\n",
    "        \n",
    "        Args:\n",
    "            surname_csv (str): location of the dataset\n",
    "            vectorizer_filepath (str): location of the saved vectorizer\n",
    "        Returns:\n",
    "            an instance of SurnameDataset\n",
    "        \"\"\"\n",
    "        text_df = pd.read_csv(dataset_csv)\n",
    "        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)\n",
    "        return cls(text_df, vectorizer)\n",
    "\n",
    "    @staticmethod\n",
    "    def load_vectorizer_only(vectorizer_filepath):\n",
    "        \"\"\"a static method for loading the vectorizer from file\n",
    "        \n",
    "        Args:\n",
    "            vectorizer_filepath (str): the location of the serialized vectorizer\n",
    "        Returns:\n",
    "            an instance of SurnameVectorizer\n",
    "        \"\"\"\n",
    "        with open(vectorizer_filepath) as fp:\n",
    "            return NMTVectorizer.from_serializable(json.load(fp))\n",
    "\n",
    "    def save_vectorizer(self, vectorizer_filepath):\n",
    "        \"\"\"saves the vectorizer to disk using json\n",
    "        \n",
    "        Args:\n",
    "            vectorizer_filepath (str): the location to save the vectorizer\n",
    "        \"\"\"\n",
    "        with open(vectorizer_filepath, \"w\") as fp:\n",
    "            json.dump(self._vectorizer.to_serializable(), fp)\n",
    "\n",
    "    def get_vectorizer(self):\n",
    "        \"\"\" returns the vectorizer \"\"\"\n",
    "        return self._vectorizer\n",
    "\n",
    "    def set_split(self, split=\"train\"):\n",
    "        self._target_split = split\n",
    "        self._target_df, self._target_size = self._lookup_dict[split]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self._target_size\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"the primary entry point method for PyTorch datasets\n",
    "        \n",
    "        Args:\n",
    "            index (int): the index to the data point \n",
    "        Returns:\n",
    "            a dictionary holding the data point: (x_data, y_target, class_index)\n",
    "        \"\"\"\n",
    "        row = self._target_df.iloc[index]\n",
    "\n",
    "        vector_dict = self._vectorizer.vectorize(row.source_language, row.target_language)\n",
    "\n",
    "        return {\"x_source\": vector_dict[\"source_vector\"], \n",
    "                \"x_target\": vector_dict[\"target_x_vector\"],\n",
    "                \"y_target\": vector_dict[\"target_y_vector\"], \n",
    "                \"x_source_length\": vector_dict[\"source_length\"]}\n",
    "        \n",
    "    def get_num_batches(self, batch_size):\n",
    "        \"\"\"Given a batch size, return the number of batches in the dataset\n",
    "        \n",
    "        Args:\n",
    "            batch_size (int)\n",
    "        Returns:\n",
    "            number of batches in the dataset\n",
    "        \"\"\"\n",
    "        return len(self) // batch_size\n",
    "\n",
    "def generate_nmt_batches(dataset, batch_size, shuffle=True, \n",
    "                            drop_last=True, device=\"cpu\"):\n",
    "    \"\"\"A generator function which wraps the PyTorch DataLoader.  The NMT Version \"\"\"\n",
    "    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,\n",
    "                            shuffle=shuffle, drop_last=drop_last)\n",
    "\n",
    "    for data_dict in dataloader:\n",
    "        lengths = data_dict['x_source_length'].numpy()\n",
    "        sorted_length_indices = lengths.argsort()[::-1].tolist()\n",
    "        \n",
    "        out_data_dict = {}\n",
    "        for name, tensor in data_dict.items():\n",
    "            out_data_dict[name] = data_dict[name][sorted_length_indices].to(device)\n",
    "        yield out_data_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neural Machine Translation Model\n",
    "\n",
    "Components:\n",
    "\n",
    "1. NMTEncoder\n",
    "    - accepts as input a source sequence to be embedded and fed through a bi-directional GRU\n",
    "2. NMTDecoder\n",
    "    - using the encoder state and attention, the decoder generates a new sequence\n",
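    "    - by default, the ground-truth target sequence is used as input to the decoder at each time step (teacher forcing)\n",
    "    - an alternative formulation lets the decoder use some of its own predictions as input instead (controlled here by `sample_probability`)\n",
    "    - this is known as *scheduled sampling*; it is related to curriculum learning and to learning to search\n",
    "        - References: Bengio et al., \"Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks\" (NeurIPS 2015), an approach used in the MSCOCO image-captioning challenge; for learning to search, see Hal Daumé III's work, e.g. SEARN (Daumé, Langford & Marcu, 2009).\n",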
    "3. NMTModel\n",
    "    - Combines the encoder and decoder into a single class. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NMTEncoder(nn.Module):\n",
    "    def __init__(self, num_embeddings, embedding_size, rnn_hidden_size):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            num_embeddings (int): number of embeddings is the size of source vocabulary\n",
    "            embedding_size (int): size of the embedding vectors\n",
    "            rnn_hidden_size (int): size of the RNN hidden state vectors \n",
    "        \"\"\"\n",
    "        super(NMTEncoder, self).__init__()\n",
    "    \n",
    "        self.source_embedding = nn.Embedding(num_embeddings, embedding_size, padding_idx=0)\n",
    "        self.birnn = nn.GRU(embedding_size, rnn_hidden_size, bidirectional=True, batch_first=True)\n",
    "    \n",
    "    def forward(self, x_source, x_lengths):\n",
    "        \"\"\"The forward pass of the model\n",
    "        \n",
    "        Args:\n",
    "            x_source (torch.Tensor): the input data tensor.\n",
    "                x_source.shape is (batch, seq_size)\n",
    "            x_lengths (torch.Tensor): a vector of lengths for each item in the batch\n",
    "        Returns:\n",
    "            a tuple: x_unpacked (torch.Tensor), x_birnn_h (torch.Tensor)\n",
    "                x_unpacked.shape = (batch, seq_size, rnn_hidden_size * 2)\n",
    "                x_birnn_h.shape = (batch, rnn_hidden_size * 2)\n",
    "        \"\"\"\n",
    "        x_embedded = self.source_embedding(x_source)\n",
    "        # create PackedSequence; x_packed.data.shape=(number_items, embeddign_size)\n",
    "        x_packed = pack_padded_sequence(x_embedded, x_lengths.detach().cpu().numpy(), \n",
    "                                        batch_first=True)\n",
    "        \n",
    "        # x_birnn_h.shape = (num_rnn, batch_size, feature_size)\n",
    "        x_birnn_out, x_birnn_h  = self.birnn(x_packed)\n",
    "        # permute to (batch_size, num_rnn, feature_size)\n",
    "        x_birnn_h = x_birnn_h.permute(1, 0, 2)\n",
    "        \n",
    "        # flatten features; reshape to (batch_size, num_rnn * feature_size)\n",
    "        #  (recall: -1 takes the remaining positions, \n",
    "        #           flattening the two RNN hidden vectors into 1)\n",
    "        x_birnn_h = x_birnn_h.contiguous().view(x_birnn_h.size(0), -1)\n",
    "        \n",
    "        x_unpacked, _ = pad_packed_sequence(x_birnn_out, batch_first=True)\n",
    "        \n",
    "        return x_unpacked, x_birnn_h\n",
    "\n",
    "def verbose_attention(encoder_state_vectors, query_vector):\n",
    "    \"\"\"A descriptive version of the neural attention mechanism \n",
    "    \n",
    "    Args:\n",
    "        encoder_state_vectors (torch.Tensor): 3dim tensor from bi-GRU in encoder\n",
    "        query_vector (torch.Tensor): hidden state in decoder GRU\n",
    "    Returns:\n",
    "        \n",
    "    \"\"\"\n",
    "    batch_size, num_vectors, vector_size = encoder_state_vectors.size()\n",
    "    vector_scores = torch.sum(encoder_state_vectors * query_vector.view(batch_size, 1, vector_size), \n",
    "                              dim=2)\n",
    "    vector_probabilities = F.softmax(vector_scores, dim=1)\n",
    "    weighted_vectors = encoder_state_vectors * vector_probabilities.view(batch_size, num_vectors, 1)\n",
    "    context_vectors = torch.sum(weighted_vectors, dim=1)\n",
    "    return context_vectors, vector_probabilities, vector_scores\n",
    "\n",
    "def terse_attention(encoder_state_vectors, query_vector):\n",
    "    \"\"\"A shorter and more optimized version of the neural attention mechanism\n",
    "    \n",
    "    Args:\n",
    "        encoder_state_vectors (torch.Tensor): 3dim tensor from bi-GRU in encoder\n",
    "        query_vector (torch.Tensor): hidden state\n",
    "    \"\"\"\n",
    "    vector_scores = torch.matmul(encoder_state_vectors, query_vector.unsqueeze(dim=2)).squeeze()\n",
    "    vector_probabilities = F.softmax(vector_scores, dim=-1)\n",
    "    context_vectors = torch.matmul(encoder_state_vectors.transpose(-2, -1), \n",
    "                                   vector_probabilities.unsqueeze(dim=2)).squeeze()\n",
    "    return context_vectors, vector_probabilities\n",
    "\n",
    "\n",
    "\n",
    "class NMTDecoder(nn.Module):\n",
    "    \"\"\"GRU-cell decoder that attends over the encoder states at every output step.\"\"\"\n",
    "\n",
    "    def __init__(self, num_embeddings, embedding_size, rnn_hidden_size, bos_index):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            num_embeddings (int): number of embeddings is also the number of \n",
    "                unique words in target vocabulary \n",
    "            embedding_size (int): the embedding vector size\n",
    "            rnn_hidden_size (int): size of the hidden rnn state; the encoder_state given\n",
    "                to forward() must have this feature size, because the attention context\n",
    "                vector is fed back into the GRU cell\n",
    "            bos_index(int): begin-of-sequence index\n",
    "        \"\"\"\n",
    "        super(NMTDecoder, self).__init__()\n",
    "        self._rnn_hidden_size = rnn_hidden_size\n",
    "        self.target_embedding = nn.Embedding(num_embeddings=num_embeddings, \n",
    "                                             embedding_dim=embedding_size, \n",
    "                                             padding_idx=0)\n",
    "        # GRU input = embedded token concatenated with the previous context vector\n",
    "        self.gru_cell = nn.GRUCell(embedding_size + rnn_hidden_size, \n",
    "                                   rnn_hidden_size)\n",
    "        self.hidden_map = nn.Linear(rnn_hidden_size, rnn_hidden_size)\n",
    "        # classifier input = [context vector; hidden state]\n",
    "        self.classifier = nn.Linear(rnn_hidden_size * 2, num_embeddings)\n",
    "        self.bos_index = bos_index\n",
    "        # scores are multiplied by this before the sampling softmax (values > 1 sharpen it)\n",
    "        self._sampling_temperature = 3\n",
    "\n",
    "    def _init_indices(self, batch_size):\n",
    "        \"\"\" return the BEGIN-OF-SEQUENCE index vector \"\"\"\n",
    "        return torch.ones(batch_size, dtype=torch.int64) * self.bos_index\n",
    "\n",
    "    def _init_context_vectors(self, batch_size):\n",
    "        \"\"\" return a zeros vector for initializing the context \"\"\"\n",
    "        return torch.zeros(batch_size, self._rnn_hidden_size)\n",
    "\n",
    "    def forward(self, encoder_state, initial_hidden_state, target_sequence, sample_probability=0.0,\n",
    "                max_output_length=None):\n",
    "        \"\"\"The forward pass of the model\n",
    "\n",
    "        Args:\n",
    "            encoder_state (torch.Tensor): the output of the NMTEncoder,\n",
    "                shape (batch, source_seq, rnn_hidden_size)\n",
    "            initial_hidden_state (torch.Tensor): The last hidden state in the NMTEncoder,\n",
    "                shape (batch, rnn_hidden_size)\n",
    "            target_sequence (torch.Tensor or None): the target text data tensor,\n",
    "                shape (batch, target_seq). When None, the decoder feeds back its own\n",
    "                samples at every step (sample_probability is forced to 1.0) and\n",
    "                max_output_length must be given.\n",
    "            sample_probability (float): the schedule sampling parameter\n",
    "                probabilty of using model's predictions at each decoder step\n",
    "            max_output_length (int or None): number of decoding steps when\n",
    "                target_sequence is None; ignored otherwise\n",
    "        Returns:\n",
    "            output_vectors (torch.Tensor): prediction vectors at each output step,\n",
    "                shape (batch, output_seq, num_embeddings)\n",
    "        Raises:\n",
    "            ValueError: if target_sequence and max_output_length are both None\n",
    "        \"\"\"\n",
    "        if target_sequence is None:\n",
    "            if max_output_length is None:\n",
    "                raise ValueError(\"max_output_length is required when target_sequence is None\")\n",
    "            sample_probability = 1.0\n",
    "            output_sequence_size = max_output_length\n",
    "        else:\n",
    "            # The input is (Batch, Seq); permute to (Seq, Batch) to iterate over time steps\n",
    "            target_sequence = target_sequence.permute(1, 0)\n",
    "            output_sequence_size = target_sequence.size(0)\n",
    "\n",
    "        device = encoder_state.device\n",
    "        batch_size = encoder_state.size(0)\n",
    "\n",
    "        # map the encoder's final hidden state to the decoder's initial hidden state\n",
    "        h_t = self.hidden_map(initial_hidden_state).to(device)\n",
    "        # the first context vector is all zeros and the first input token is BOS\n",
    "        context_vectors = self._init_context_vectors(batch_size).to(device)\n",
    "        y_t_index = self._init_indices(batch_size).to(device)\n",
    "\n",
    "        output_vectors = []\n",
    "        # auxiliary caches (numpy copies) kept for inspection / attention visualization\n",
    "        self._cached_p_attn = []\n",
    "        self._cached_ht = []\n",
    "        self._cached_decoder_state = encoder_state.cpu().detach().numpy()\n",
    "\n",
    "        for i in range(output_sequence_size):\n",
    "            # Scheduled sampling: with probability sample_probability this step's input is\n",
    "            # the model's own sample from the previous step (BOS at step 0); otherwise\n",
    "            # it is the ground-truth token.\n",
    "            use_sample = np.random.random() < sample_probability\n",
    "            if not use_sample:\n",
    "                y_t_index = target_sequence[i]\n",
    "\n",
    "            # Step 1: Embed word and concat with previous context\n",
    "            y_input_vector = self.target_embedding(y_t_index)\n",
    "            rnn_input = torch.cat([y_input_vector, context_vectors], dim=1)\n",
    "\n",
    "            # Step 2: Make a GRU step, getting a new hidden vector\n",
    "            h_t = self.gru_cell(rnn_input, h_t)\n",
    "            self._cached_ht.append(h_t.cpu().detach().numpy())\n",
    "\n",
    "            # Step 3: Use the current hidden to attend to the encoder state\n",
    "            context_vectors, p_attn, _ = verbose_attention(encoder_state_vectors=encoder_state, \n",
    "                                                           query_vector=h_t)\n",
    "            self._cached_p_attn.append(p_attn.cpu().detach().numpy())\n",
    "\n",
    "            # Step 4: Use the current hidden and context vectors to make a prediction to the next word\n",
    "            prediction_vector = torch.cat((context_vectors, h_t), dim=1)\n",
    "            # F.dropout defaults to training=True; tie it to the module mode so that\n",
    "            # eval() really disables dropout\n",
    "            score_for_y_t_index = self.classifier(\n",
    "                F.dropout(prediction_vector, 0.3, training=self.training))\n",
    "\n",
    "            if sample_probability > 0.0:\n",
    "                # sample a candidate next input at every step (the next step may still\n",
    "                # replace it with ground truth), so a sampled step never reuses a stale\n",
    "                # ground-truth token from an earlier step\n",
    "                p_y_t_index = F.softmax(score_for_y_t_index * self._sampling_temperature, dim=1)\n",
    "                # squeeze(1), not squeeze(): keep the batch axis when batch_size == 1\n",
    "                y_t_index = torch.multinomial(p_y_t_index, 1).squeeze(1)\n",
    "\n",
    "            # auxiliary: collect the prediction scores\n",
    "            output_vectors.append(score_for_y_t_index)\n",
    "\n",
    "        # (seq, batch, vocab) -> (batch, seq, vocab)\n",
    "        return torch.stack(output_vectors).permute(1, 0, 2)\n",
    "    \n",
    "    \n",
    "class NMTModel(nn.Module):\n",
    "    \"\"\"Encoder-decoder neural machine translation model.\n",
    "\n",
    "    An `NMTEncoder` reads the source batch and an `NMTDecoder` produces\n",
    "    prediction vectors for the target sequence.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, source_vocab_size, source_embedding_size,\n",
    "                 target_vocab_size, target_embedding_size, encoding_size,\n",
    "                 target_bos_index):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            source_vocab_size (int): number of unique words in source language\n",
    "            source_embedding_size (int): size of the source embedding vectors\n",
    "            target_vocab_size (int): number of unique words in target language\n",
    "            target_embedding_size (int): size of the target embedding vectors\n",
    "            encoding_size (int): hidden size of the encoder RNN; the decoder\n",
    "                RNN is built with twice this size\n",
    "            target_bos_index (int): begin-of-sequence index in the target\n",
    "                vocabulary, forwarded to the decoder\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.encoder = NMTEncoder(num_embeddings=source_vocab_size,\n",
    "                                  embedding_size=source_embedding_size,\n",
    "                                  rnn_hidden_size=encoding_size)\n",
    "        # Decoder width is 2 * encoding_size, presumably to match a\n",
    "        # bidirectional encoder's concatenated states - TODO confirm.\n",
    "        self.decoder = NMTDecoder(num_embeddings=target_vocab_size,\n",
    "                                  embedding_size=target_embedding_size,\n",
    "                                  rnn_hidden_size=2 * encoding_size,\n",
    "                                  bos_index=target_bos_index)\n",
    "\n",
    "    def forward(self, x_source, x_source_lengths, target_sequence, sample_probability=0.0):\n",
    "        \"\"\"Encode the source batch, then decode the target sequence.\n",
    "\n",
    "        Args:\n",
    "            x_source (torch.Tensor): the source text data tensor.\n",
    "                x_source.shape should be (batch, vectorizer.max_source_length)\n",
    "            x_source_lengths (torch.Tensor): the length of the sequences in x_source\n",
    "            target_sequence (torch.Tensor): the target text data tensor\n",
    "            sample_probability (float): the scheduled sampling parameter, i.e.\n",
    "                the probability of using the model's own predictions at each\n",
    "                decoder step\n",
    "        Returns:\n",
    "            decoded_states (torch.Tensor): prediction vectors at each output step\n",
    "        \"\"\"\n",
    "        encoded = self.encoder(x_source, x_source_lengths)\n",
    "        encoder_state, final_hidden_states = encoded\n",
    "        return self.decoder(encoder_state=encoder_state,\n",
    "                            initial_hidden_state=final_hidden_states,\n",
    "                            target_sequence=target_sequence,\n",
    "                            sample_probability=sample_probability)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Routine and Bookkeeping Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed_everywhere(seed, cuda):\n",
    "    \"\"\"Seed numpy and torch, plus every CUDA device when `cuda` is truthy.\"\"\"\n",
    "    for seed_fn in (np.random.seed, torch.manual_seed):\n",
    "        seed_fn(seed)\n",
    "    if cuda:\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "def handle_dirs(dirpath):\n",
    "    \"\"\"Create `dirpath`, including missing parents, unless it already exists.\n",
    "\n",
    "    `exist_ok=True` makes the call idempotent without the race between a\n",
    "    separate existence check and the creation. A path that exists but is\n",
    "    not a directory raises FileExistsError instead of being silently kept.\n",
    "    \"\"\"\n",
    "    os.makedirs(dirpath, exist_ok=True)\n",
    "\n",
    "def make_train_state(args):\n",
    "    \"\"\"Create the bookkeeping dict for a training run.\n",
    "\n",
    "    Args:\n",
    "        args: namespace providing `learning_rate` and `model_state_file`\n",
    "    Returns:\n",
    "        dict holding early-stopping fields, empty per-epoch metric\n",
    "        histories and -1 placeholders for the test metrics\n",
    "    \"\"\"\n",
    "    state = dict(stop_early=False,\n",
    "                 early_stopping_step=0,\n",
    "                 early_stopping_best_val=1e8,\n",
    "                 learning_rate=args.learning_rate,\n",
    "                 epoch_index=0)\n",
    "    for history in ('train_loss', 'train_acc', 'val_loss', 'val_acc'):\n",
    "        state[history] = []\n",
    "    state.update(test_loss=-1, test_acc=-1,\n",
    "                 model_filename=args.model_state_file)\n",
    "    return state\n",
    "\n",
    "def update_train_state(args, model, train_state):\n",
    "    \"\"\"Handle the training state updates.\n",
    "\n",
    "    Components:\n",
    "     - Early Stopping: Prevent overfitting.\n",
    "     - Model Checkpoint: Model is saved if the model is better\n",
    "\n",
    "    :param args: main arguments; `args.early_stopping_criteria` is the number\n",
    "        of consecutive epochs whose val loss fails to drop below the previous\n",
    "        epoch's before training stops\n",
    "    :param model: model to train; its `state_dict()` is what gets saved\n",
    "    :param train_state: a dictionary representing the training state values;\n",
    "        `val_loss` must already include the current epoch's loss\n",
    "    :returns:\n",
    "        the same train_state dict, updated in place\n",
    "    \"\"\"\n",
    "\n",
    "    # Save one model at least\n",
    "    if train_state['epoch_index'] == 0:\n",
    "        torch.save(model.state_dict(), train_state['model_filename'])\n",
    "        # Record the epoch-0 loss as the best so far. Otherwise the best stays\n",
    "        # at the 1e8 sentinel and a later epoch that is worse than epoch 0\n",
    "        # would overwrite this (better) checkpoint.\n",
    "        if train_state['val_loss']:\n",
    "            train_state['early_stopping_best_val'] = train_state['val_loss'][-1]\n",
    "        train_state['stop_early'] = False\n",
    "\n",
    "    # Save model if performance improved\n",
    "    elif train_state['epoch_index'] >= 1:\n",
    "        loss_tm1, loss_t = train_state['val_loss'][-2:]\n",
    "\n",
    "        # If loss worsened (relative to the previous epoch)\n",
    "        if loss_t >= loss_tm1:\n",
    "            # Update step\n",
    "            train_state['early_stopping_step'] += 1\n",
    "        # Loss decreased\n",
    "        else:\n",
    "            # Save the best model\n",
    "            if loss_t < train_state['early_stopping_best_val']:\n",
    "                torch.save(model.state_dict(), train_state['model_filename'])\n",
    "                train_state['early_stopping_best_val'] = loss_t\n",
    "\n",
    "            # Reset early stopping step\n",
    "            train_state['early_stopping_step'] = 0\n",
    "\n",
    "        # Stop early ?\n",
    "        train_state['stop_early'] = (\n",
    "            train_state['early_stopping_step'] >= args.early_stopping_criteria)\n",
    "\n",
    "    return train_state\n",
    "\n",
    "def normalize_sizes(y_pred, y_true):\n",
    "    \"\"\"Flatten predictions and targets so they align for token-level metrics.\n",
    "\n",
    "    Args:\n",
    "        y_pred (torch.Tensor): model output; a 3-D tensor is reshaped to a\n",
    "            matrix whose columns are its last dimension\n",
    "        y_true (torch.Tensor): target indices; a 2-D tensor is reshaped to\n",
    "            a vector\n",
    "    Returns:\n",
    "        tuple (y_pred, y_true) of the possibly reshaped tensors\n",
    "    \"\"\"\n",
    "    if y_pred.dim() == 3:\n",
    "        n_columns = y_pred.size(2)\n",
    "        y_pred = y_pred.contiguous().view(-1, n_columns)\n",
    "    if y_true.dim() == 2:\n",
    "        y_true = y_true.contiguous().view(-1)\n",
    "    return y_pred, y_true\n",
    "\n",
    "def compute_accuracy(y_pred, y_true, mask_index):\n",
    "    \"\"\"Percentage of unmasked target positions whose argmax prediction is correct.\n",
    "\n",
    "    Args:\n",
    "        y_pred (torch.Tensor): prediction scores, flattened by normalize_sizes\n",
    "        y_true (torch.Tensor): target indices, flattened by normalize_sizes\n",
    "        mask_index (int): target index excluded from the count (padding)\n",
    "    Returns:\n",
    "        float accuracy in percent\n",
    "    \"\"\"\n",
    "    flat_pred, flat_true = normalize_sizes(y_pred, y_true)\n",
    "    predicted = flat_pred.max(dim=1)[1]\n",
    "\n",
    "    is_correct = (predicted == flat_true).float()\n",
    "    is_valid = (flat_true != mask_index).float()\n",
    "\n",
    "    n_valid = is_valid.sum().item()\n",
    "    n_correct = (is_correct * is_valid).sum().item()\n",
    "    return n_correct / n_valid * 100\n",
    "\n",
    "def sequence_loss(y_pred, y_true, mask_index):\n",
    "    \"\"\"Token-level cross-entropy that ignores targets equal to `mask_index`.\"\"\"\n",
    "    flat_pred, flat_true = normalize_sizes(y_pred, y_true)\n",
    "    loss = F.cross_entropy(flat_pred, flat_true, ignore_index=mask_index)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expanded filepaths: \n",
      "\tmodel_storage/ch8/nmt_luong_sampling/vectorizer.json\n",
      "\tmodel_storage/ch8/nmt_luong_sampling/model.pth\n",
      "Using CUDA: True\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters and paths for this run\n",
    "args = Namespace(\n",
    "    # data / output locations\n",
    "    dataset_csv=\"data/nmt/simplest_eng_fra.csv\",\n",
    "    vectorizer_file=\"vectorizer.json\",\n",
    "    model_state_file=\"model.pth\",\n",
    "    save_dir=\"model_storage/ch8/nmt_luong_sampling\",\n",
    "    reload_from_files=False,\n",
    "    expand_filepaths_to_save_dir=True,\n",
    "    # runtime\n",
    "    cuda=True,\n",
    "    seed=1337,\n",
    "    # optimization\n",
    "    learning_rate=5e-4,\n",
    "    batch_size=32,\n",
    "    num_epochs=100,\n",
    "    early_stopping_criteria=5,\n",
    "    # model sizes\n",
    "    source_embedding_size=24,\n",
    "    target_embedding_size=24,\n",
    "    encoding_size=32,\n",
    "    # misc\n",
    "    catch_keyboard_interrupt=True,\n",
    ")\n",
    "\n",
    "if args.expand_filepaths_to_save_dir:\n",
    "    # place the vectorizer and checkpoint files inside save_dir\n",
    "    args.vectorizer_file = os.path.join(args.save_dir, args.vectorizer_file)\n",
    "    args.model_state_file = os.path.join(args.save_dir, args.model_state_file)\n",
    "\n",
    "    print(\"Expanded filepaths: \")\n",
    "    for expanded_path in (args.vectorizer_file, args.model_state_file):\n",
    "        print(\"\\t{}\".format(expanded_path))\n",
    "\n",
    "# Check CUDA: keep it enabled only when a GPU is actually available\n",
    "args.cuda = args.cuda and torch.cuda.is_available()\n",
    "args.device = torch.device(\"cuda\" if args.cuda else \"cpu\")\n",
    "print(\"Using CUDA: {}\".format(args.cuda))\n",
    "\n",
    "# Set seed for reproducibility\n",
    "set_seed_everywhere(args.seed, args.cuda)\n",
    "\n",
    "# handle dirs\n",
    "handle_dirs(args.save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the saved vectorizer only when reloading is requested and it exists\n",
    "reuse_vectorizer = args.reload_from_files and os.path.exists(args.vectorizer_file)\n",
    "\n",
    "if not reuse_vectorizer:\n",
    "    # create dataset and vectorizer, then persist the vectorizer\n",
    "    dataset = NMTDataset.load_dataset_and_make_vectorizer(args.dataset_csv)\n",
    "    dataset.save_vectorizer(args.vectorizer_file)\n",
    "else:\n",
    "    # training from a checkpoint\n",
    "    dataset = NMTDataset.load_dataset_and_load_vectorizer(args.dataset_csv,\n",
    "                                                          args.vectorizer_file)\n",
    "\n",
    "vectorizer = dataset.get_vectorizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New model\n"
     ]
    }
   ],
   "source": [
    "model = NMTModel(source_vocab_size=len(vectorizer.source_vocab),\n",
    "                 source_embedding_size=args.source_embedding_size,\n",
    "                 target_vocab_size=len(vectorizer.target_vocab),\n",
    "                 target_embedding_size=args.target_embedding_size,\n",
    "                 encoding_size=args.encoding_size,\n",
    "                 target_bos_index=vectorizer.target_vocab.begin_seq_index)\n",
    "\n",
    "if args.reload_from_files and os.path.exists(args.model_state_file):\n",
    "    # map_location remaps stored tensors onto the current device, so a\n",
    "    # checkpoint written on a GPU can be restored when args.device fell\n",
    "    # back to \"cpu\".\n",
    "    model.load_state_dict(torch.load(args.model_state_file,\n",
    "                                     map_location=args.device))\n",
    "    print(\"Reloaded model\")\n",
    "else:\n",
    "    print(\"New model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c582817f88ce4b908f29a905792909b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', style=ProgressStyle(description_width='ini…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c724cf67edcb4c71b4e94b0fe7d8daab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='split=train', max=285, style=ProgressStyle(description_width=…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "34ab5dd08b10456399433b53fb60ac4d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='split=val', max=61, style=ProgressStyle(description_width='in…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exiting loop\n"
     ]
    }
   ],
   "source": [
    "model = model.to(args.device)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)\n",
    "# halve the learning rate when the validation loss plateaus\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,\n",
    "                                                 mode='min', factor=0.5,\n",
    "                                                 patience=1)\n",
    "mask_index = vectorizer.target_vocab.mask_index\n",
    "train_state = make_train_state(args)\n",
    "\n",
    "epoch_bar = tqdm_notebook(desc='training routine',\n",
    "                          total=args.num_epochs,\n",
    "                          position=0)\n",
    "\n",
    "dataset.set_split('train')\n",
    "train_bar = tqdm_notebook(desc='split=train',\n",
    "                          total=dataset.get_num_batches(args.batch_size),\n",
    "                          position=1,\n",
    "                          leave=True)\n",
    "dataset.set_split('val')\n",
    "val_bar = tqdm_notebook(desc='split=val',\n",
    "                        total=dataset.get_num_batches(args.batch_size),\n",
    "                        position=1,\n",
    "                        leave=True)\n",
    "\n",
    "try:\n",
    "    for epoch_index in range(args.num_epochs):\n",
    "        # Scheduled sampling: the probability of feeding the decoder its own\n",
    "        # predictions grows linearly with the epoch, starting at 20 / num_epochs.\n",
    "        sample_probability = (20 + epoch_index) / args.num_epochs\n",
    "\n",
    "        train_state['epoch_index'] = epoch_index\n",
    "\n",
    "        # Iterate over training dataset\n",
    "\n",
    "        # setup: batch generator, set loss and acc to 0, set train mode on\n",
    "        dataset.set_split('train')\n",
    "        batch_generator = generate_nmt_batches(dataset,\n",
    "                                               batch_size=args.batch_size,\n",
    "                                               device=args.device)\n",
    "        running_loss = 0.0\n",
    "        running_acc = 0.0\n",
    "        model.train()\n",
    "\n",
    "        for batch_index, batch_dict in enumerate(batch_generator):\n",
    "            # the training routine is these 5 steps:\n",
    "\n",
    "            # --------------------------------------\n",
    "            # step 1. zero the gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # step 2. compute the output\n",
    "            y_pred = model(batch_dict['x_source'],\n",
    "                           batch_dict['x_source_length'],\n",
    "                           batch_dict['x_target'],\n",
    "                           sample_probability=sample_probability)\n",
    "\n",
    "            # step 3. compute the loss\n",
    "            loss = sequence_loss(y_pred, batch_dict['y_target'], mask_index)\n",
    "\n",
    "            # step 4. use loss to produce gradients\n",
    "            loss.backward()\n",
    "\n",
    "            # step 5. use optimizer to take gradient step\n",
    "            optimizer.step()\n",
    "\n",
    "            # -----------------------------------------\n",
    "            # compute the running loss and running accuracy\n",
    "            # (incremental mean over the batches seen so far)\n",
    "            running_loss += (loss.item() - running_loss) / (batch_index + 1)\n",
    "\n",
    "            acc_t = compute_accuracy(y_pred, batch_dict['y_target'], mask_index)\n",
    "            running_acc += (acc_t - running_acc) / (batch_index + 1)\n",
    "\n",
    "            # update bar\n",
    "            train_bar.set_postfix(loss=running_loss, acc=running_acc,\n",
    "                                  epoch=epoch_index)\n",
    "            train_bar.update()\n",
    "\n",
    "        train_state['train_loss'].append(running_loss)\n",
    "        train_state['train_acc'].append(running_acc)\n",
    "\n",
    "        # Iterate over val dataset\n",
    "\n",
    "        # setup: batch generator, set loss and acc to 0; set eval mode on\n",
    "        dataset.set_split('val')\n",
    "        batch_generator = generate_nmt_batches(dataset,\n",
    "                                               batch_size=args.batch_size,\n",
    "                                               device=args.device)\n",
    "        running_loss = 0.\n",
    "        running_acc = 0.\n",
    "        model.eval()\n",
    "\n",
    "        # Validation never calls backward(), so run it under no_grad: PyTorch\n",
    "        # then skips building the autograd graph, saving memory and time\n",
    "        # without changing the computed losses.\n",
    "        # NOTE(review): the decoder calls F.dropout(prediction_vector, 0.3)\n",
    "        # without training=self.training, so that dropout appears to stay\n",
    "        # active under model.eval() - confirm; val metrics may be noisy.\n",
    "        with torch.no_grad():\n",
    "            for batch_index, batch_dict in enumerate(batch_generator):\n",
    "                # compute the output\n",
    "                y_pred = model(batch_dict['x_source'],\n",
    "                               batch_dict['x_source_length'],\n",
    "                               batch_dict['x_target'],\n",
    "                               sample_probability=sample_probability)\n",
    "\n",
    "                # compute the loss\n",
    "                loss = sequence_loss(y_pred, batch_dict['y_target'], mask_index)\n",
    "\n",
    "                # compute the running loss and accuracy\n",
    "                running_loss += (loss.item() - running_loss) / (batch_index + 1)\n",
    "\n",
    "                acc_t = compute_accuracy(y_pred, batch_dict['y_target'], mask_index)\n",
    "                running_acc += (acc_t - running_acc) / (batch_index + 1)\n",
    "\n",
    "                # Update bar\n",
    "                val_bar.set_postfix(loss=running_loss, acc=running_acc,\n",
    "                                    epoch=epoch_index)\n",
    "                val_bar.update()\n",
    "\n",
    "        train_state['val_loss'].append(running_loss)\n",
    "        train_state['val_acc'].append(running_acc)\n",
    "\n",
    "        # checkpoint / early-stopping bookkeeping, then step the LR scheduler\n",
    "        train_state = update_train_state(args=args, model=model,\n",
    "                                         train_state=train_state)\n",
    "\n",
    "        scheduler.step(train_state['val_loss'][-1])\n",
    "\n",
    "        if train_state['stop_early']:\n",
    "            break\n",
    "\n",
    "        # rewind the per-split progress bars for the next epoch\n",
    "        train_bar.n = 0\n",
    "        val_bar.n = 0\n",
    "        epoch_bar.set_postfix(best_val=train_state['early_stopping_best_val'])\n",
    "        epoch_bar.update()\n",
    "\n",
    "except KeyboardInterrupt:\n",
    "    print(\"Exiting loop\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation and plotting dependencies\n",
    "from nltk.translate import bleu_score\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Smoothing-function factory for sentence-level BLEU;\n",
    "# NMTSampler.get_ith_item passes chencherry.method1 to sentence_bleu.\n",
    "chencherry = bleu_score.SmoothingFunction()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sentence_from_indices(indices, vocab, strict=True, return_string=True):\n",
    "    \"\"\"Map a sequence of token indices back to tokens.\n",
    "\n",
    "    Args:\n",
    "        indices (iterable of int): token indices, e.g. one row of a batch\n",
    "        vocab: vocabulary exposing `begin_seq_index`, `end_seq_index` and\n",
    "            `lookup_index(index)`\n",
    "        strict (bool): if True, skip begin-of-sequence tokens and stop at the\n",
    "            first end-of-sequence token; if False, translate every index\n",
    "            (mask indices included). Mask indices before the first\n",
    "            end-of-sequence token are translated in either mode.\n",
    "        return_string (bool): join tokens with single spaces if True,\n",
    "            otherwise return the list of tokens\n",
    "    Returns:\n",
    "        str or list of str\n",
    "    \"\"\"\n",
    "    out = []\n",
    "    for index in indices:\n",
    "        if strict and index == vocab.begin_seq_index:\n",
    "            continue\n",
    "        if strict and index == vocab.end_seq_index:\n",
    "            break\n",
    "        out.append(vocab.lookup_index(index))\n",
    "    return \" \".join(out) if return_string else out\n",
    "    \n",
    "class NMTSampler:\n",
    "    \"\"\"Runs an NMT model on one batch and decodes/scores individual examples.\"\"\"\n",
    "\n",
    "    def __init__(self, vectorizer, model):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            vectorizer: object exposing `source_vocab` and `target_vocab`\n",
    "            model: the translation model; its `decoder._cached_p_attn` is\n",
    "                read after each forward pass\n",
    "        \"\"\"\n",
    "        self.vectorizer = vectorizer\n",
    "        self.model = model\n",
    "\n",
    "    def apply_to_batch(self, batch_dict):\n",
    "        \"\"\"Run the model on `batch_dict` and cache its predictions and attention.\n",
    "\n",
    "        `batch_dict` is stored by reference and gains 'y_pred' and\n",
    "        'attention' entries, so the caller's dict is modified in place.\n",
    "        \"\"\"\n",
    "        self._last_batch = batch_dict\n",
    "        # Inference only: no_grad skips building the autograd graph.\n",
    "        # NOTE(review): the decoder applies F.dropout(..., 0.3) without\n",
    "        # training=self.training, so predictions may stay stochastic even\n",
    "        # after model.eval() - confirm.\n",
    "        with torch.no_grad():\n",
    "            y_pred = self.model(x_source=batch_dict['x_source'],\n",
    "                                x_source_lengths=batch_dict['x_source_length'],\n",
    "                                target_sequence=batch_dict['x_target'])\n",
    "        self._last_batch['y_pred'] = y_pred\n",
    "        # Argmax over the last axis once per batch rather than once per item.\n",
    "        _, self._last_pred_indices = torch.max(y_pred, dim=2)\n",
    "\n",
    "        # assumes the decoder refills _cached_p_attn (one array per output\n",
    "        # step) on every forward pass - TODO confirm in NMTDecoder\n",
    "        attention_batched = np.stack(self.model.decoder._cached_p_attn).transpose(1, 0, 2)\n",
    "        self._last_batch['attention'] = attention_batched\n",
    "\n",
    "    def _get_source_sentence(self, index, return_string=True):\n",
    "        \"\"\"Decode row `index` of the cached source batch.\"\"\"\n",
    "        indices = self._last_batch['x_source'][index].cpu().detach().numpy()\n",
    "        vocab = self.vectorizer.source_vocab\n",
    "        return sentence_from_indices(indices, vocab, return_string=return_string)\n",
    "\n",
    "    def _get_reference_sentence(self, index, return_string=True):\n",
    "        \"\"\"Decode row `index` of the cached gold targets ('y_target').\"\"\"\n",
    "        indices = self._last_batch['y_target'][index].cpu().detach().numpy()\n",
    "        vocab = self.vectorizer.target_vocab\n",
    "        return sentence_from_indices(indices, vocab, return_string=return_string)\n",
    "\n",
    "    def _get_sampled_sentence(self, index, return_string=True):\n",
    "        \"\"\"Decode the model's argmax prediction for row `index`.\"\"\"\n",
    "        sentence_indices = self._last_pred_indices[index].cpu().detach().numpy()\n",
    "        vocab = self.vectorizer.target_vocab\n",
    "        return sentence_from_indices(sentence_indices, vocab, return_string=return_string)\n",
    "\n",
    "    def get_ith_item(self, index, return_string=True):\n",
    "        \"\"\"Collect source, reference, prediction, attention and BLEU-4 for one example.\n",
    "\n",
    "        Args:\n",
    "            index (int): row of the batch last passed to `apply_to_batch`\n",
    "            return_string (bool): return sentences as space-joined strings\n",
    "                (True) or as token lists (False)\n",
    "        Returns:\n",
    "            dict with keys 'source', 'reference', 'sampled', 'attention'\n",
    "            and 'bleu-4' (smoothed sentence BLEU over word tokens)\n",
    "        \"\"\"\n",
    "        reference_tokens = self._get_reference_sentence(index, return_string=False)\n",
    "        sampled_tokens = self._get_sampled_sentence(index, return_string=False)\n",
    "\n",
    "        output = {\"source\": self._get_source_sentence(index, return_string=return_string),\n",
    "                  \"reference\": \" \".join(reference_tokens) if return_string else reference_tokens,\n",
    "                  \"sampled\": \" \".join(sampled_tokens) if return_string else sampled_tokens,\n",
    "                  \"attention\": self._last_batch['attention'][index]}\n",
    "\n",
    "        # sentence_bleu expects token lists: a list of reference token lists and\n",
    "        # a hypothesis token list. Passing joined strings would make nltk score\n",
    "        # character n-grams instead of word n-grams.\n",
    "        output['bleu-4'] = bleu_score.sentence_bleu(references=[reference_tokens],\n",
    "                                                    hypothesis=sampled_tokens,\n",
    "                                                    smoothing_function=chencherry.method1)\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.eval().to(args.device)\n",
    "\n",
    "sampler = NMTSampler(vectorizer, model)\n",
    "\n",
    "dataset.set_split('test')\n",
    "batch_generator = generate_nmt_batches(dataset,\n",
    "                                       batch_size=args.batch_size,\n",
    "                                       device=args.device)\n",
    "\n",
    "test_results = []\n",
    "for batch_dict in batch_generator:\n",
    "    sampler.apply_to_batch(batch_dict)\n",
    "    # Use the real batch size: a final partial batch may hold fewer than\n",
    "    # args.batch_size examples, and indexing past it would raise.\n",
    "    for i in range(len(batch_dict['x_source'])):\n",
    "        test_results.append(sampler.get_ith_item(i, False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.28705687208416497, 0.26730098289088333)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADiZJREFUeJzt3X2sJuVdxvHvVSjWl1YoHAkB1qV2K66mFl0JpolRaA1SBUxJA7EGEnRjxVrTJhatf9SXRNCk2EQSs5aG1WgBUQPWqkG6pGlTqEt5E0jLgjSCFLYVrNVY3fbnH88Ah2UPzzzvc+79fpKTnZlnzjO/nXPOde5z33PPpKqQJG1+L1t1AZKk+TDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY04cpkHO+6442rr1q3LPKQkbXp33nnnl6pqbdx+Sw30rVu3snfv3mUeUpI2vSRf6LOfXS6S1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktSIpc4U1bBsvfxvn1t+9Iq3rLASSfNgC12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgfcDEwPnRC0rRsoUtSIwx0SWqEgS5JjTDQJakRDorqRRyYlTYnW+iS1AgDXZIaYaBLUiMMdElqRO9AT3JEkruSfLRbPyXJHUn2Jbk+yVGLK1OSNM4kLfR3AQ+uW78SuKqqXgs8DVw6z8IkSZPpFehJTgLeAnyoWw9wJnBjt8tu4PxFFChJ6qdvC/0PgF8FvtGtHws8U1UHuvXHgBPnXJskaQJjAz3JTwJPVdWd0xwgyc4ke5Ps3b9//zRvIUnqoU8L/Y3AuUkeBa5j1NXyQeDoJM/OND0JePxQn1xVu6pqR1XtWFtbm0PJkqRDGRvoVfVrVXVSVW0FLgQ+XlU/A+wBLuh2uxi4aWFVSpLGmuU69PcC706yj1Gf+jXzKUmSNI2Jbs5VVbcBt3XLjwCnz78kSdI0nCkqSY0w0CWpEQa6JDXCB1xscvN6GMX695G0OdlCl6RGGOiS1AgDXZIaYaBLUiMOu0FRn2g/mVnPl+dbWh5b6JLUCANdkhphoEtSI5rtQ+/Td2v/7vD4NZGmZwtdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1IhmJxbp0GZ5MlHfST8+/UhaDVvoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEY4sWiTmPRJPkOc3LPopxH5tCMd7myhS1IjDHRJaoSBLkmNsA+9s4z+V/t4n+e5kOZvbAs9ySuSfCbJPUnuT/Kb3fZTktyRZF+S65MctfhyJUkb6dPl8jXgzKr6fuANwNlJzgCuBK6qqtcCTwOXLq5MSdI4YwO9Rr7arb68+yjgTODGbvtu4PyFVChJ6qXXoGiSI5LcDTwF3AI8DDxTVQe6XR4DTlxMiZKkPnoNilbV14E3JDka+Gvg1L4HSLIT2AmwZcuWaWrsbRGTaQ5+zyEP4A1xMtG8OClJGm+iyxar6hlgD/DDwNFJnv2FcBLw+Aafs6uqdlTVjrW1tZmKlSRtrM9VLmtdy5wk3wy8GXiQUbBf0O12MXDTooqUJI3Xp8vlBGB3kiMY/QK4oao+muQB4LokvwPcBVyzwDolSWOMDfSquhc47RDbHwFOX0RRemkt95VLmp5T/yWpEQa6JDXCQJekRnhzrgEbcl/5UK7PH8I5WuU17F4/r/VsoUtSIwx0SWqEgS5JjTDQJakRDoougANVklbBFrokNcJAl6RGGOiS1IhN34feZ2LJKiefDGHiy9AN+Rw5HvI8z8Xw2UKXpEYY6JLUCANdkhphoEtSIzb9oOgiDHmQTsM2lLtQ6vBkC12SGmGgS1Ij
DHRJaoR96BNycsXhy6+9hs4WuiQ1wkCXpEYY6JLUCPvQNRdeuz+effBaNFvoktQIA12SGmGgS1IjDHRJaoSDotIUHATWENlCl6RGGOiS1AgDXZIaYR+6DitO7lHLxrbQk5ycZE+SB5Lcn+Rd3fZXJ7klyUPdv8csvlxJ0kb6dLkcAN5TVduBM4DLkmwHLgduraptwK3duiRpRcYGelU9UVWf7Zb/E3gQOBE4D9jd7bYbOH9RRUqSxptoUDTJVuA04A7g+Kp6onvpi8Dxc61MkjSR3oOiSb4N+EvgV6rqK0mee62qKklt8Hk7gZ0AW7Zsma1aqREOzmoRerXQk7ycUZj/WVX9Vbf5ySQndK+fADx1qM+tql1VtaOqdqytrc2jZknSIfS5yiXANcCDVfWBdS/dDFzcLV8M3DT/8iRJffXpcnkj8LPAfUnu7rb9OnAFcEOSS4EvAG9bTImSpD7GBnpVfRLIBi+fNd9yNpd53aDJGz0d2iznpU8f9Ubvv+yvx6S12ueujTj1X5IaYaBLUiMMdElqhDfnWhH7zSXNmy10SWqEgS5JjTDQJakRBrokNcJBUTVvlQPQkx7bCUSahS10SWqEgS5JjTDQJakR9qFLm4yT0rQRW+iS1AgDXZIaYaBLUiMMdElqhIOi2nQcFJyPWZ7qtIxja3K20CWpEQa6JDXCQJekRtiHLg3UIvqvHX9omy10SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiOcWCQ1whteyRa6JDXCQJekRhjoktQI+9ClxvW5IZf9720Y20JP8uEkTyX553XbXp3kliQPdf8es9gyJUnj9OlyuRY4+6BtlwO3VtU24NZuXZK0QmMDvao+Afz7QZvPA3Z3y7uB8+dclyRpQtMOih5fVU90y18Ejp9TPZKkKc18lUtVFVAbvZ5kZ5K9Sfbu379/1sNJkjYwbaA/meQEgO7fpzbasap2VdWOqtqxtrY25eEkSeNMG+g3Axd3yxcDN82nHEnStPpctvgR4NPAdyd5LMmlwBXAm5M8BLypW5ckrdDYiUVVddEGL50151okzUmfyUSLPu4yJig5IeqFnPovSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoRPLJK0EE76WT5b6JLUCANdkhphoEtSI+xDlzSTVd0ITC9mC12SGmGgS1IjDHRJaoSBLkmN2JSDog7CSIvT5+dr0p9Bf2aXwxa6JDXCQJekRhjoktSITdOHbh+cJL00W+iS1AgDXZIaYaBLUiMMdElqxKYZFJWkvg7XpyXZQpekRhjoktQIA12SGmEfuqSV6jNpcH0/+Eb7z2vy4bz631fRj28LXZIaYaBLUiMMdElqhH3okgZvlv7xw+ma9Jla6EnOTvK5JPuSXD6voiRJk5s60JMcAVwN/ASwHbgoyfZ5FSZJmswsLfTTgX1V9UhV/S9wHXDefMqSJE1qlkA/EfjXdeuPddskSSuw8EHRJDuBnd3qV5N8boq3OQ740vyqmpuh1gXWNo2h1gXDrW2odcEhasuVk73BpPv3fJ9pztl39tlplkB/HDh53fpJ3bYXqKpdwK4ZjkOSvVW1Y5b3WISh1gXWNo2h1gXDrW2odcFwa1tkXbN0ufwTsC3JKUmOAi4Ebp5PWZKkSU3dQq+qA0l+CfgH4Ajgw1V1/9wqkyRNZKY+9Kr6GPCxOdXyUmbqslmgodYF1jaNodYFw61tqHXBcGtbWF2pqkW9tyRpibyXiyQ1YlCBPu5WAkm+Kcn13et3JNk6kLp+JMlnkxxIcsEyapqgtncneSDJvUluTdLr8qcl1PULSe5LcneSTy5zlnHfW1YkeWuSSrKUKyV6nLNLkuzvztndSX5uGXX1qa3b523d99r9Sf58KLUluWrdOft8kmcGUteWJHuS3NX9fJ4z80GrahAfjAZWHwZeAxwF3ANsP2ifXwT+qFu+ELh+IHVtBV4P/AlwwcDO2Y8B39It
v2NA5+xV65bPBf5+KOes2++VwCeA24EdQ6gLuAT4w2V9f01Y2zbgLuCYbv07hlLbQfu/k9EFHCuvi1Ff+ju65e3Ao7Med0gt9D63EjgP2N0t3wiclSSrrquqHq2qe4FvLLiWaWrbU1X/3a3ezmi+wBDq+sq61W8FljWY0/eWFb8NXAn8z8DqWoU+tf08cHVVPQ1QVU8NqLb1LgI+MpC6CnhVt/ztwL/NetAhBXqfWwk8t09VHQD+Azh2AHWtyqS1XQr83UIrGulVV5LLkjwM/B7wy0uoq1dtSX4AOLmq5vNMsznV1Xlr9+f5jUlOPsTri9CnttcBr0vyqSS3Jzl7QLUB0HU3ngJ8fCB1vR94e5LHGF0t+M5ZDzqkQNcCJXk7sAP4/VXX8qyqurqqvgt4L/Abq64HIMnLgA8A71l1LYfwN8DWqno9cAvP/7U6BEcy6nb5UUat4D9OcvRKK3qxC4Ebq+rrqy6kcxFwbVWdBJwD/Gn3/Te1IQV6n1sJPLdPkiMZ/Zny5QHUtSq9akvyJuB9wLlV9bWh1LXOdcD5C63oeeNqeyXwfcBtSR4FzgBuXsLA6NhzVlVfXvf1+xDwgwuuqXdtjFqgN1fV/1XVvwCfZxTwQ6jtWReynO4W6FfXpcANAFX1aeAVjO7zMr1lDFz0HEQ4EniE0Z9Ezw4ifO9B+1zGCwdFbxhCXev2vZblDor2OWenMRqc2TawuratW/4pYO9Qajto/9tYzqBon3N2wrrlnwZuH8o5A84GdnfLxzHqbjh2CLV1+50KPEo392YIdTHq/rykW/4eRn3oM9W38P/YhCfhHEa/2R8G3tdt+y1GLUsY/Qb7C2Af8BngNQOp64cYtVD+i9FfDPcP6Jz9I/AkcHf3cfNA6vogcH9X056XCtVl13bQvksJ9J7n7He7c3ZPd85OHco5A8Koq+oB4D7gwqHU1q2/H7hiWTX1PGfbgU91X8+7gR+f9ZjOFJWkRgypD12SNAMDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRvw/UCPsqiPI7IYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "bleu_scores = [r['bleu-4'] for r in test_results]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(bleu_scores, bins=100)\n",
    "ax.set(title=\"Sentence-level BLEU-4 on the test split\",\n",
    "       xlabel=\"BLEU-4\", ylabel=\"number of sentences\")\n",
    "\n",
    "# mean and median BLEU-4 over the test split\n",
    "np.mean(bleu_scores), np.median(bleu_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the model in eval mode and wrap it in a sampler\n",
    "model = model.eval().to(args.device)\n",
    "sampler = NMTSampler(vectorizer, model)\n",
    "\n",
    "# Grab one validation batch and decode it\n",
    "dataset.set_split('val')\n",
    "batch_generator = generate_nmt_batches(dataset,\n",
    "                                       batch_size=args.batch_size,\n",
    "                                       device=args.device)\n",
    "batch_dict = next(batch_generator)\n",
    "sampler.apply_to_batch(batch_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One result per example actually present in the batch: it can hold fewer\n",
    "# than args.batch_size rows, and indexing past the end would raise.\n",
    "all_results = [sampler.get_ith_item(i, False)\n",
    "               for i in range(len(batch_dict['x_source']))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# keep the examples whose sentence BLEU-4 exceeds 0.5\n",
    "top_results = list(filter(lambda result: result['bleu-4'] > 0.5, all_results))\n",
    "len(top_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ4AAAEUCAYAAAAbV1CxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAH/FJREFUeJzt3X28pXO9//HX27g3UtKN0Awaatw0SZNC+R0qqVD8otRJ0XSnOyelfo6KzvmVSkcdlVEdKUeJfppEo9xUKOxhwgyjQTLSKUJIxuz9/v1xXZs1y7659s1a11prv5+Px/XY1813ffdn7T2zPvv7vb7X9yvbREREtMsadQcQERFTSxJPRES0VRJPRES0VRJPRES0VRJPRES0VRJPRES0VRJPRES0VRJPRES0VRJPRES0VRJPRES01Zp1BzAVzJwxI/MS9QDVHcAwNp22bt0hDOt/+v9RdwhDmrPeU+sOYVjn3HTNxP6pLTmn+ufNdgfU8s86LZ6IiGirtHgiInqI+/srl62rFZ8WT0REtFVaPBERvaR/Vd0RjCqJJyKih3igeuJJV1tEREwJafFERPSSMQwuqEtaPBER0VY9mXgkXSppmaTFkm6UNK/h2kaSTpe0XNIt5f5G5bU1JH1Z0g2Srpd0taQt63snERFj4/5Vlbe69EzikbS2pA0aTh1iew6wK/A5SWuX578J3Gr7Oba3Bm4DvlFeOwh4FrCj7R2A1wP3lfU/pR3vIyJiQvpXVd9q0vWJR9LzJH0RWAZsM0SR6cBDQL+k5wAvBI5vuH4csLOkrYFNgbtsDwDYXmH73rLcuZIWSNpXUu6NRUSMU1cmHkkbSHq7pMuAU4GlFK2UaxuKnSHpOoqEdLztfmA2sLjcB6DcXwxsB5wFvK7sovuipBc01LcHcCJwIHCjpH8vE1lERMfwwKrKW126MvEAdwGHAYfb3s32N20/0FTmENs7As8GPiJpxmiV2l4BbAt8HBgALpK0Z3nNti+1/c8UrSYDN0k6YKi6JM2T1Cep74EHHxzv+4yI6DndmngOBO4Efijp2JGSiu2/ANcAL6ZoGc2R9Nj7LvfnlNew/YjtC2wfBfw7sH9D2fUkvRn4IfAq4IPAz4b5vvNt72x75w2nT5/Yu42IqKq/v/pWk65MPLYvtH0QsDtwP/AjST+XNLO5rKT1gRcAt9heDlwLHNNQ5BjgGtvLJe0k6Vnl69YAdgRuL49PoEhOLwWOKpPKybb/1qr3GRExVt0wqq2rb5Lbvgc4CThJ0lygMYWfIelhYB3gNNuLyvOHAV+RdEt5/OvyHMDTgVMlrVMeXwX8Z7l/KXCs7c5cYCQiokt0deJpZPuqhv09Rih3L/CWYa79FPjpMNfOn2CIERGt1wWThHZlV1tERHSvnmnxREQEeKDz52pL4omI6CF1DhqoKl1tERHRVmnxRET0krR4IiIiVpcWT0RED+mGwQVp8URERFulxdMGnZrdZ629Yd0hDGnhuZ+uO4QhbbXPkXWHMKQ/9mcyjbFa9PA9dYfQOl1wjyeJJyKih2Q4dURERJO0eCIieklaPBEREatLiyciood0w3DqJJ6IiF6SrraIiIjVpcUTEdFD3N/5XW1p8URERFsl8UyQpCvqjiEiYpD7V1XeqpC0t6RlkpZLOnqI68+WdImkayVdJ2mf0epMV9sE2X5p3TFERDxmYPIGF0iaBpwMvAJYAVwtaYHtpQ3FjgHOsv01SbOB84GZI9WbFs8ESXqw7hgiIlpkLrDc9q22VwLfA/ZrKmPgSeX+RsAfR6s0LZ6IiB4yyYMLNgPuaDheAby4qcyngAslvR/YANhrtErT4mkRSfMk9Unq+9uDaRRFROdp/Jwqt3njqOZNwGm2Nwf2Ab4jacTckhZPi9ieD8wH2GrGDNccTkRMFWNo8TR+Tg3jTmCLhuPNy3ONDgP2Luv7taR1gU2APw9XaVo8EREx
nKuBWZK2lLQ2cDCwoKnMH4A9ASQ9D1gX+MtIlabFExHRQyZzPR7bqyQdASwEpgHfsr1E0nFAn+0FwL8Ap0r6MMVAg0Ntj9jLk8QzQban1x1DRMRjJnnmAtvnUwyRbjx3bMP+UmDXsdSZrraIiGirtHgiInpI5mqLiIhokhZPREQPyUJwERHRXulqi4iIWF1aPBERPSSDCyIiIpqkxdMGt55/Yt0hDOmhKy+oO4QhbbXPkXWHMKQnqTP/u9znyXtSPbqf+wfqDmFUnfk/KSIixqcLEk+62iIioq3S4omI6CEZXBAREdEkLZ6IiB7i/s5fdzItnoiIaKu0eCIiekiGU0dERFt1Q+JJV1tERLRVWjwRET3EAxlc0LUk/b7uGCIielFaPBERPaQbhlMn8QzvLwCS9gA+DdwH7ACcBVwPfBBYD9jf9i01xRgRsRp3/sQF6Wobju0XNRw+H3g38DzgrcA2tucC3wDeX0N4ERFda9TEI2l9Sf8q6dTyeJak17Y+tI5yte27bD8C3AJcWJ6/Hpg51AskzZPUJ6lv/g9+1qYwI2Kqc78rb3Wp0tX2X8Ai4CXl8Z3AD4DzWhVUB3qkYX+g4XiAYX6GtucD8wFYck7nd7pGRLRJla62rW2fADwKYPvvgFoaVUREjMvAQPWtLlVaPCslrQcYQNLWrN4CiIiIDtENgwuqJJ5PAj8FtpB0BrArcGgrg+okti8FLm043mO4axERMbpRE4/tn0m6BtiFoovtg7bvbnlkERExZt3Q4qkyqu31wCrbP7F9HrBK0v6tDy0iInpRlcEFn7R9/+CB7fsout8iIqLDdMPggiqJZ6gymfEgIiLGpUoC6ZN0InByefw+iud6IiKiw/TEPR6KKWFWAt8vt0cokk9ERHSYgQFV3upSZVTbQ8DRbYglIiKmgCqj2raRNF/ShZIuHtzaEVxERIzNZA8ukLS3pGWSlksashEi6Y2SlkpaIum/R6uzyj2eHwBfp5iJuQt6DzvPzH2OrDuEmARXnvTmukMY0o4f+E7dIQxrJZ05TeG0ugPoEpKmUdzffwWwArha0gLbSxvKzAI+Duxq+15JTx+t3iqJZ5Xtr40z7oiIaKNJHlwwF1hu+1YASd8D9gOWNpR5J3Cy7XsBbP95tEqrDC74saT3StpU0saD29jjj4iIVpvkwQWbAXc0HK8ozzXaBthG0uWSfiNp79EqrdLieVv59aiGcwa2qvDaiIjoUJLmAfMaTs0vl3QZizWBWcAewObALyXtUE42MOwLRmR7yzEGERERNRkYQ1fbauuGDe1OYIuG483Lc41WAFfafhS4TdLNFIno6uEqrboC6TGS5pfHU3EF0oiIqehqYJakLSWtDRwMLGgqcy5FawdJm1B0vd06UqVV7vH8F8UDpC8tj+8EPlM57IiIaJvJvMdjexVwBLAQuBE4y/YSScdJ2rcsthC4R9JS4BLgKNv3jFRvlXs8W9s+SNKbykD+LikrkEZEdCBP8owEts8Hzm86d2zDvoEjy62SKi2erEAaERGTpkqL51M8cQXSt7cyqIiIGJ86lzuoqsqotgslLSIrkK5G0qHAhbb/WHcsERHdpMqotots3zO4AqntuyVd1I7gOtyhwLPqDiIiolFXz04taV1gfWATSU+haO0APIknPrna9STNBC4ALqMYwXcnxdQQ21LMVbc+cAvwDmBPYGfgDEkPAy+x/XD7o46I6D4jtXjeRbHg23PLr4Pbj4D/bH1otZhFMefQdsB9wAHA6cDHbO8IXE+xFPjZQB9wiO05SToR0Sm6usVj+yTgJEnvt/2VNsZUp9tsLy73FwFbA0+2/Yvy3LcpZuseVeNUFBtvvDEbTp8+2bFGRDxBf40Jpaoqgwu+IumlwMzG8rZPb2FcdWkcJt4PPHm8FTVORTFzxozOnBs+IqIGoyYeSd+h+Mt/MY+vx2OKLqhedz9wr6Tdbf8KeCsw2Pp5ANiwtsgiIoZQZxdaVVWe49kZmF0+nToV
vQ34uqT1KeYfGnyG6bTyfAYXRESMQZXEcwPwTOCuFsdSK9u/B7ZvOP5Cw+Vdhih/DnBO6yOLiKhuwL3R4tkEWCrpKhrugdjed/iXREREHXpi5gKKKXMiIiImRZVRbb+QNAOYZfvn5b2Oaa0PLSIixqq/C7raqkyZ807gbOCU8tRmFAv/REREjFmVrrb3AXOBKwFs/07S01saVUREjEuvDKd+xPbKwbXfJK1JuTZPRER0lp7oagN+IekTwHqSXkExZcyPWxtWRET0qiqJ52jgLxQTZL6LYgnUY1oZVEREjM+AVXmrS5VRbQPAqZK+DWwH3DmFZzGIKWzbD3TmLFFV+svrshad3+0T7Tdsi0fS1yVtV+5vRDFX2+nAtZLe1Kb4IiJiDPqtyltdRupq2932knL/7cDNtncAXgh8tOWRRURETxqplb6yYX9wUAG2/zQ4wi0iIjpLfxfcCBkp8dwn6bUUS0DvChwGjw2nXq8NsUVExBh1+ySh7wK+TDEz9Yds/6k8vyfwk1YHFhERvWmkpa9vBvYe4vxCYGErg4qIiPHplQdIIyIiJk0nPwIQERFj1O2DCyIiosv0d8FDu1WWRXiGpG9KuqA8ni3psNaHFhERvajKPZ7TKAYTPKs8vhn4UKsCioiI8et39a0uVRLPJrbPAgYAbK8C+lsaVYeTdK6kRZKWSJpXdzwREd2kyj2ehyQ9lXINHkm7APe3NKrO9w7bf5W0HnC1pHNs31N3UBER3dAqqJJ4jgQWAFtLuhx4GnBgS6PqfB+Q9PpyfwtgFrBa4ilbQvMANt54YzacPr29EUbElNQTicf2NZJeDmwLCFhm+9GWR9ahJO0B7AW8xPbfJV0KrNtczvZ8YD7AzBkzumCAY0REe1QZ1fY+YLrtJbZvAKZLem/rQ+tYGwH3lknnucAudQcUETGoH1Xe6lJlcME7bd83eGD7XuCdrQup4/0UWFPSjcBngd/UHE9ERMtI2lvSMknLJR09QrkDJFnSzqPVWeUezzRJGlx1VNI0YO3qYfcW248Ar647joiIofRP4gLR5ef9yRRL46ygGEy1wPbSpnIbAh8ErqxSb5UWz0Lg+5L2lLQncCbFX/0REdHb5gLLbd9qeyXwPWC/IcodD3wO+EeVSqskno8CFwPvKbeLyAqkEREdqX8Mm6R5kvoatubnEjcD7mg4XlGee4yknYAtbFdeLmfErraymXW67UOAr1etNCIi6jGW4dSNo2/HQ9IawInAoWN53YgtHtv9wAxJU/aeTkTEFHYnxbOKgzYvzw3aENgeuFTS7ylG+S4YbYBBlcEFtwKXS1oAPDR40vaJ1eKOiIh2meQHSK8GZknakiLhHAy8efCi7fuBTQaPy+caP2K7b6RKqySeW8ptDYrsFhERU4DtVZKOoBhkNg34lu0lko4D+mwvGE+9VWYu+PR4Ko6IiPbrZ3InSrF9PnB+07ljhym7R5U6R008ki6BJ74T2/9U5RtERET79MRcbcBHGvbXBQ4AVrUmnIgYq3O3e2bdIQxrm/fvW3cIQ/r18QvrDmFKq9LVtqjp1OWSrmpRPBERMQGTOXNBq1Tpatu44XAN4IUUE2VGRESMWZWutkUU93hE0cV2G3BYK4OKiIjx6Yl7PLa3bEcgERExcZM9qq0VqnS1rUUxR9vLylOXAqdM5cXgIiJi/Kp0tX0NWAv4ann81vLc4a0KKiIixqcnWjzAi2w/v+H4Ykm/bVVAERHR26osi9AvaevBA0lb0R33ryIippyxLItQlyotnqOASyTdSjGybQbw9pZGFRERPWvUFo/ti4BZwAeA9wPb2r6k1YG1gqQnS3pvub+HpPOGKfcNSbNHqWt3SUskLZa0XivijYgYq3678laXYROPpBdJeiaA7UeAORTLm36+6aHSbvJk4L2jFbJ9ePOa4kM4BPi/tufYfnhSoouImKB+XHmry0gtnlOAlQCSXgZ8FjgduJ8JrFhXs88CW0taDHwemC7pbEk3STpD
kqBYU2JwISNJr5T0a0nXSPqBpOmSDgfeCBwv6Yza3k1ERBca6R7PNNt/LfcPAubbPgc4p/zg7kZHA9vbniNpD+BHwHbAH4HLgV2BywYLS9oEOAbYy/ZDkj4GHGn7OEm7AefZPrvdbyIiYjjdMJx6pBbPNEmDiWlP4OKGa1UGJXSDq2yvsD0ALAZmNl3fBZhNMTHqYuBtFIMrRiVpnqQ+SX0PPPjgZMYcEdHVRkogZwK/kHQ38DDwKwBJz6HobusFjzTs9/PEn4eAn9l+01grtj2fskty5owZnf8nSET0hIFunp3a9r9JugjYFLjQfuzdrEExuq0bPcDYlu/+DXCypOfYXi5pA2Az2ze3JryIiInphq62EbvMbP9miHNd+6Fr+x5Jl0u6gaIV9z+jlP+LpEOBMyWtU54+Bujan0FERN165V5NZbbfPMz5Ixr292jYvxh40RDlD21BeBERE9INLZ4qU+ZERERMminX4omI6GU9sfR1RER0j3S1RURENEmLJyKih3TDczxp8URERFulxRMR0UO64R5PEk9ERA9J4omIlttvyZ/qDmFYA+/uzBVUOvkewy11B9AGSTwRET0kgwsiIiKapMUTEdFDuuEeT1o8ERHRVmnxRET0kMzVFhERbTWQrraIiOhmkvaWtEzScklHD3H9SElLJV0n6SJJM0arM4knIqKH9NuVt9FImgacDLwamA28SdLspmLXAjvb3hE4GzhhtHqTeCIiYjhzgeW2b7W9EvgesF9jAduX2P57efgbYPPRKs09noiIHjLJD5BuBtzRcLwCePEI5Q8DLhit0iSeiIgeMpbneCTNA+Y1nJpve1zzHEl6C7Az8PLRyibxRERMUWWSGSnR3Als0XC8eXluNZL2Av4P8HLbj4z2fZN4IiJ6yIAHJrO6q4FZkrakSDgHA29uLCDpBcApwN62/1yl0gwuaBFJ8yT1Sep74MEH6w4nImLMbK8CjgAWAjcCZ9leIuk4SfuWxT4PTAd+IGmxpAWj1St3wVOu3W7mjBn5IUfLTKs7gBFM6t/ek6iT/+K+5fbbNZHXv2rW9pU/bxb+7oYJfa/xSldbREQP6YYpczo58XcFSedLelbdcUREdIu0eCbI9j51xxARMShztUVERDRJiycioodk6euIiIgmafFERPSQTh3C3iiJJyKih6SrLSIioklaPBERPSTDqSMiIpqkxRMRLdOpf3t3ww348eqGezxJPBERPSRdbREREU3S4omI6CFp8URERDRJiycioocMdH6DJ4knIqKXpKstIiKiSVo8ERE9JC2eiIiIJmnxRET0kC6YuKB3WzySLpW0TNLicju74do8STeV21WSdmu49lpJ10r6raSlkt5VzzuIiOhNPdXikbQ2sJbth8pTh9juayrzWuBdwG6275a0E3CupLnAPcB8YK7tFZLWAWaWr3uK7Xvb9V4iIsYj93jaRNLzJH0RWAZsM0rxjwFH2b4bwPY1wLeB9wEbUiTje8prj9heVr7uIEk3SPoXSU9rxfuIiJgoj2GrS9cmHkkbSHq7pMuAU4GlwI62r20odkZDV9vny3PbAYuaqusDtrP9V2ABcLukMyUdImkNANtfB14NrA/8UtLZkvYevD5EfPMk9Unqe+DBByftfUdEdLtu7mq7C7gOONz2TcOUeUJX22hsHy5pB2Av4CPAK4BDy2t3AMdL+gxFEvoWRdLad4h65lN02zFzxozOb/tGRE9IV1trHQjcCfxQ0rGSZlR83VLghU3nXggsGTywfb3tL1EknQMaC5b3gr4KfBk4C/j4+MKPiJiaujbx2L7Q9kHA7sD9wI8k/VzSzFFeegLwOUlPBZA0h6JF81VJ0yXt0VB2DnB7We6Vkq4DPgNcAsy2/SHbS4iI6BDdcI+nm7vaALB9D3AScFLZGulvuHyGpIfL/btt72V7gaTNgCskGXgAeIvtuyRtCHxU0inAw8BDlN1sFAMOXmf79ja8rYiIcen8jjaQu+Fpoy6X
ezzRStPqDmAE/aMXqYXqDmAEt91++4TC22bGzMqfNzff/vtafhRd3+KJiIjHZXBBREREk7R4IiJ6SOe3d5J4IiJ6SjcknnS1RUTEsMoZWpZJWi7p6CGuryPp++X1Kys80pLEExHRSybzOR5J04CTKWZqmQ28SdLspmKHAffafg7wJeBzo9WbxBMREcOZCyy3favtlcD3gP2ayuxHMdEywNnAnpJGHKadxBMR0UMmeeaCzYA7Go5XlOeGLGN7FcVMMk8dqdIMLmiD30/wgbBGkuaVE5B2lMQ1dp0aW+Iam06LayyfN5LmAfMaTs1vx3tJi6f7zBu9SC0S19h1amyJa2w6Na5R2Z5ve+eGrTnp3Als0XC8eXluyDKS1gQ2olzTbDhJPBERMZyrgVmStixXeD6YYs2yRguAt5X7BwIXe5S52NLVFhERQ7K9StIRwEKKaQG/ZXuJpOOAPtsLgG8C35G0HPgrRXIaURJP9+mYvuQmiWvsOjW2xDU2nRrXpLB9PnB+07ljG/b/AfzvsdSZ2akjIqKtco8nIiLaKoknIiLaKomnC0jaTdLby/2nSdqy7pgiIsYriafDSfok8DHg4+WptYDv1hdRd5H0FEk71h3HUCStIelJdccRk0fSM+uOoRsk8XS+1wP7Ag8B2P4jsGGtEZUkbSPpIkk3lMc7SjqmA+K6VNKTJG0MXAOcKunEuuMCkPTfZWwbADcASyUdVXNMHfl7hM6ObRjfrDuAbpDE0/lWlg9jGaD8wOoUp1K0xB4FsH0dFcbwt8FGtv8GvAE43faLgb1qjmnQ7DK2/YELgC2Bt9YbUsf+HqGzY3sC26+pO4ZukMTT+c6SdArwZEnvBC4CvlFzTIPWt31V07lVtUSyujUlbQq8ETiv7mCarCVpLYrEs8D2o9S/dlen/h6hs2OLccoDpB3O9hckvQL4G7ANcIztn9cc1qC7JW3N462xA4G76g0JgOMonrS+zPbVkrYCfldzTINOAX4P/Bb4paQZFL/bOnXq7xE6O7YYpzxA2qEkXWZ7N0kPUPyna5xxdoBiaorP2/5qLQEC5Qf6fOClwL3AbcAhtm+vK6ZuJGnNcjr5ur5/x/4eOzm2GL8kni4l6anAFba3rTGGdSgmBZwJbEzxl7ttH1dTPB+1fYKkrzBE95XtD9QQ1hNIeg2wHbDu4Lk6fmaSjmw6tR5F9/vgQJbaBmR0cmwxcelq61K275G0R81h/Ai4j2Lk2B9rjgXgxvJrX61RjEDS14H1gf9Fca/uQKD5Hka7DI6O3BZ4EcXvUxSDHeqKaVAnxxYTlBZPjJukG2xvX3cc3UTSdbZ3bPg6HbjA9u41xvRL4DW2HyiPNwR+YvtldcU0qJNji/FLiycm4gpJO9i+vu5AGkm6hKG72v6phnCaPVx+/bukZ1EsmLVpjfEAPANY2XC8sjzXCTo5thinJJ6YiN2AQyXdBjxC0RVi23XPFPCRhv11gQPonCG450l6MnACsKg8V/fw+NOBqyT9v/J4f+C0+sJZTSfHFuOUrrYYt3Io8BN04ogjSVfZntsBcawHvAfYnaJV9ivga+WaJnXGtVMZE8AvbV9bZzyNOjm2GJ8knug55VQ5g9YAdgZOqnME4CBJZwEP8Ph8e2+mmGnhjfVFFdFe6WqLXrSIx599epTigc3D6gyowfa2ZzccXyJpaW3RRNQgU+ZEL/oYMMf2lsB3KJ79+Hu9IT3mGkm7DB5IejEdPPw7ohXS1RY9p2Go8m7A8cAXgGPLyUJrJelGimdT/lCeejawjGLwQycMzIhouXS1RS/qL7++BjjV9k8kfabOgBrsXXcAEXVLiyd6jqTzgDuBVwA7UTw7c5Xt59caWEQASTzRgyStT9GyuN7278olEnawfWHNoUUESTwREdFmGdUWERFtlcQTERFtlVFt0XPKtYouKg+fSTHK7S/l8VzbK4d84cS+507A023/dIhr0ynmY9uO4qHWe4FX2R7zs0WS3gAstX3TBEOO
qE0ST/Qc2/cAcwAkfQp40PYXqr5e0jTb/aOXXM1OwPbAExIP8GHgD7YPLut/LsWMCuPxBooVaJN4omulqy2mFEk/lrRI0hJJh5fn1pR0n6T/kHQdMFfSvpKWlWW/Iuncsux0SadJukrStZJeV078eSxwiKTFkg5s+rabUgzvBsD2TbYfLet7W1nXYklflbRGQzyflfRbSb+W9HRJuwP7AF8qy8+UNEvSwjLOX0rapqz3u5JOknSFpFslvb7hZ/AJSdeXdf9beW7IeiJawna2bD27AZ8CPtJwvHH5dX1gKfAUipa/gTc0XFsBzKDoGvsBcG557QTg4HL/KcDNFEsvHA78xzAxvJCiq+8KipkUnlOe3x44F1izPJ5PMWnoYDyvLs+fCBxd7n8X2L+h7kuArcv9XYELG8qdWca/I3BTef51FDNir9f08xiynmzZWrGlqy2mmg9L2rfc3xzYGlhMscDY4Jovs4FlLpd3kHQm8M/ltVcCr5Z0dHm8LsW0N8OyvUjSVuVr9wL6JM0t919UHgOsB9xRvuxh2xeU+4t4fFmAx5Tr+uwCnFO+HlbvPj/XtoHrJG1WntsL+Jbth8vY/lqhnohJlX9cMWVI2gt4GbCL7YclXUaROKD4oK/yUJsoWhy3NNU94lLMLpZuPofiw13Aq8u6vmX7X5vqWpPVV93sZ+j/qwLutj1nmG/7SFPZ4YxWT8Skyj2emEo2Av5aJp3tKFobQ1kKbCtpizJJHNRwbSHw/sEDSS8odx8ANhyqMkm7la0KJK0DPA+4Hfg58EZJm5TXnippxNZT4/exfS9w1+D9m/L+0GjTAv0MeEd5XwpJG4+znohxS+KJqeQnwPrl+jefAa4cqpCLYc5HUCSGPuA+4P7y8qeBDcqb80so7iEBXAw8vxxw0Dy4YBbwK0nXA9cAvwZ+ZPv6sr6fl4MaLgSeMcp7OBP4xODgAuBg4N2SfgssAV470ottn0cx8q5P0mKKEXeMtZ6IiciUORFDkDTd9oNli+cUinnfvlJ3XBG9IC2eiKG9p2wRLKW46X9qzfFE9Iy0eCIioq3S4omIiLZK4omIiLZK4omIiLZK4omIiLZK4omIiLZK4omIiLb6/20splF7QmyUAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ4AAAEUCAYAAAAbV1CxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHPFJREFUeJzt3Xm4XWV99vHvTRAZgpTBAQEDMliCAkKMlkGpYAuUQYUKiG/VF4xvK7xOoKgUKVIvBcUiBSUMrwV5QRDFWCKjoKJWSCCAiQRDIIUUq4wSRELOufvHWodsjmdY5yR7r7PXvj/Xta6s4dnP/u19kfz4PevZz5JtIiIiOmWNugOIiIjeksQTEREdlcQTEREdlcQTEREdlcQTEREdlcQTEREdlcQTEREdlcQTEREdlcQTEREdlcQTEREdtWbdAfSCLadMybpE8QL3XXZi3SFMGJNeslHdIUwsOxyiVXr9/Cur/3uzqu81Tql4IiKio1LxREQ0iPv6KretpdwhFU9ERHRYKp6IiCbpW1F3BKNK4omIaBD3V088GWqLiIiekIonIqJJxjC5oC6peCIioqNS8URENIgzuSAiIjqqCxJPhtoiIqKjUvFERDTIWKZT1yUVT0REdFQjE4+kmyUtlDRP0q8kzWi5toGkiyQtknRfub9BeW0NSV+V9EtJd0u6TdJW9X2SiIgx6uurvtWkMYlH0lqS1ms5daTtnYHdgS9KWqs8fwGw2PY2trcG7gfOL68dBrwS2NH264B3AE+U/W/Yic8REbEq3Lei8laXrk88kraX9GVgIbDdEE0mA08DfZK2AXYFPtdy/RRgmqStgU2Bh233A9h+yPbjZburJM2SdJCk3BuLiBinrkw8ktaT9H5JtwDnAQsoqpQ7WppdIukuioT0Odt9wFRgXrkPQLk/D9gBuBw4sByi+7Kk17f0txdwBnAo8CtJny8TWUTExNG3ovpWk65MPMDDwFHA0bb3sH2B7acGtTnS9o7Aq4DjJE0ZrVPbDwGvAT4F9AM3Stq7vGbbN9v+O4qqycA9kg4Zqi9JMyTNkTTnqWXLxvs5IyIap1sTz6HAUuA7kk4aKanY/h1wO/BGispoZ0nPf+5yf+fyGraftf0D28cDnwfe3tJ2HUnvBr4D/DXwYeD6Yd53pu1ptqetP3nyqn3aiIiK3N9XeatLVyYe29fZPgzYE3gS+J6kGyRtObitpHWB1wP32V4E3AG0PvD+ROB224sk7SLpleXr1gB2BJaUx6dRJKfdgOPLpHK27d+363NGRIxVN0wu6Oqb5LYfBc4EzpQ0HWhN4ZdIegZ4MfAN23PL80cBZ0m6rzz+eXkO4GXAeZJeXB7fCvxruX8zcJLtP7blw0RE9IiuTjytbN/asr/XCO0eB94zzLVrgGuGuTZ7FUOMiGi/rNUWERHxQo2peCIiglonDVSViiciIjoqFU9ERJN0wT2eJJ6IiAbphieQZqgtIiI6KhVPRESTpOKJiIh4oVQ8EREN0g3TqZN4IiKaJENtERERL5SKJyKiQdyXobYAVHcAE8SC04Z8Zl5P2ubwU+sOYcJw3QFMMA8saf7fkySeiIgGyQ9IIyKis/pXVN8qkLSvpIWSFkk6YYjrr5J0k6Q7JN0laf/R+kziiYiIIUmaBJwN7AdMBY6QNHVQsxOBy22/HjgcOGe0fjPUFhHRIKt5csF0YJHtxQCSLgMOBha0viXwknJ/A+C/Rus0iSciIoazGfBgy/FDwBsHtTkZuE7SscB6wD6jdZqhtoiIJunrq7xJmiFpTss2YxzveATwDdubA/sDF0saMbek4omI6FG2ZwIzR2iyFNii5Xjz8lyro4B9y/5+LmltYBPgt8N1moonIqJB3Lei8lbBbcC2kraStBbF5IFZg9r8J7A3gKTtgbWB343UaSqeiIgmWY2TC2yvkHQMcC0wCbjQ9nxJpwBzbM8CPg6cJ+mj
FBMN3md7xN8FJ/FERMSwbM8GZg86d1LL/gJg97H0mcQTEdEg3bBWW+7xRERER6XiiYhokDwILiIiOitDbc0n6Wd1xxAR0U1S8awi27vVHUNExIBMLugBkpbVHUNERDdJxRMR0SDu6687hFEl8bRJudjeDICNN9qI9SdPrjmiiOgJXZB4MtTWJrZn2p5me1qSTkTESql4IiIaJJMLIiIiBknFs4psZxwtIiYM9424MPSEkIonIiI6KhVPRESDZDp1RER0VDckngy1RURER6XiiYhoEPdnckFERMQLpOKJiGiQbphOncQTEdEgnvgLF2SoLSIiOmvUxCNpXUn/KOm88nhbSQe0P7SIiBgr97nyVpcqFc//A54F/qI8Xgqc2raIIiKi0aoknq1tnwY8B2D7D4DaGlVERIxLf3/1rS5VJhcsl7QOYABJW1NUQBERMcF0w+SCKonns8A1wBaSLgF2B97XzqCa5v7ZZ9QdwoSw24GfrjuECWPiT3iNaJ9RE4/t6yXdDryJYojtw7YfaXtkERExZt1Q8VSZ1fYOYIXtq23/O7BC0tvbH1pERDRRlckFn7X95MCB7Scoht8iImKC6YbJBVUSz1BtsuJBRESMS5UEMkfSGcDZ5fGHgLntCykiIsarEfd4gGOB5cC3yu1ZiuQTERETTH+/Km91qTKr7WnghA7EEhERPWDUxCNpO+A4YMvW9rbf2r6wIiJiPOqcNFBVlXs8VwBfB84HumD0MCIiVhdJ+wJnApOA821/YYg27wJOpvht9J223z1Sn1USzwrbXxt7uBER0Wmrc3KBpEkUE8veBjwE3CZplu0FLW22BT4F7G77cUkvG63fKpMLvi/pHyRtKmmjgW2cnyMiItpoNU8umA4ssr3Y9nLgMuDgQW0+AJxt+3EA278drdMqFc97yz+Pbzln4NUVXhsREd1rM+DBluOHgDcOarMdgKSfUgzHnWz7mpE6rTKrbauxxRkREXXpH8NQm6QZwIyWUzNtzxzjW64JbAvsBWwO/FjS68pVboZ9wWiBrQt8DHiV7RnleN5rynXbGkvSA7a3rDuOiIh2KZPMSIlmKbBFy/Hm5blWDwG/sP0ccL+keykS0W3DdVr1CaTLgd1aAskTSCMiJqDVfI/nNmBbSVtJWgs4HJg1qM1VFNUOkjahGHpbPFKneQLp8H4HIGkvST+S9D1JiyV9QdKRkm6VdHf5YLyIiAnB/aq8jdqXvQI4BrgW+BVwue35kk6RdFDZ7FrgUUkLgJuA420/OlK/eQLpMGy/oeVwJ2B74DGKTH6+7emSPkyxpNBHaggxIqLtbM8GZg86d1LLvilux3ysap9VKp6TeeETSG8EPln1DRriNtsP234WuA+4rjx/N8WKDn9C0gxJcyTNmXnF9R0KMyJ6XTc8FqHKrLbrJM2lt59A2lrh9bcc9zPMd/iCm3bzr8yTjiMiSlVmtd1oe2/g6iHORUTEBFLnqtNVDZt4JK0NrAtsImlDVk4oeAnFj4oiIiLGbKSK54MUN81fSfHgt4HE83vgX9sc14Rh+2bg5pbjvYa7FhFRt66ueGyfCZwp6VjbZ3UwpoiIGKe+bk48A2yfJWk3/vR5PBe1Ma6IiGioKpMLLga2Buax8nk8BpJ4IiImmK4eamsxDZha/kgoIiJilVRJPL8EXgE83OZYIiJiFfW7GRXPJsACSbfS8kNK2wcN/5KIiKhDnSsSVFUl8Zzc7iAiIqJ3VJnV9iNJU4Btbd9QPp9nUvtDi4iIserrgqG2URcJlfQB4NvAueWpzSievxARETFmVYbaPgRMB34BYPvXkl7W1qgiImJcmjKd+lnby6Xiw0hak/LZPBERMbE0YqgN+JGkTwPrSHobcAXw/faGFRERTVUl8ZxA8RjouykWDp0NnNjOoCIiYnz6rcpbXarMausHzpP0b8AOwNKsYjA2W+5f+YmwERGNN2zFI+nrknYo9zegWKvtIuAOSUd0KL6IiBiDPqvyVpeRhtr2
tD2/3H8/cK/t1wG7Ap9oe2QREdFIIw21LW/ZH5hUgO3fDMxwi4iIiaWvC26EjJR4npB0ALAU2B04Cp6fTr1OB2KLiIgx6vZFQj8IfJViZeqP2P5NeX5v4Op2BxYREc000qOv7wX2HeL8tcC17QwqIiLGpyk/II2IiFhtqiyZExERXaLbJxdERESX6aMBQ22SXi7pAkk/KI+nSjqq/aFFREQTVbnH8w2KyQSvLI/vBT7SroAiImL8+lx9q0uVxLOJ7cuBfgDbK4C+tkYVERGNVeUez9OSNqZ8Bo+kNwFPtjWqiIgYl26oCqokno8Bs4CtJf0UeClwaFujioiIcWlE4rF9u6S3AK8BBCy0/VzbIxsDSctsT14N/ZwMLLP9pU6/d0REr6gyq+1DwGTb823/Epgs6R/aH1pERIxVH6q81aXK5IIP2H5i4MD248AH2hfS+KlwuqRfSrpb0mHl+cmSbpR0e3n+4JbXfEbSvZJuoajqBs5vLekaSXMl/UTSn5fnt5L087KfUzv+ISMiOkjSvpIWSlok6YQR2h0iyZKmjdZnlcQzSS3PQZA0CVirWsgd905gZ2AnYB/gdEmbAn8E3mF7F+AvgS+XSWpX4PDyNfsDb2jpayZwrO1dgeOAc8rzZwJfK59N9PBwgUiaIWmOpDlPLVu2Wj9kRMRw+uzK22jKf+/PBvYDpgJHSJo6RLv1gQ8Dv6gSY5XEcy3wLUl7S9obuBS4pkrnNdgDuNR2n+3/Bn5EkUwEfF7SXcANwGbAy4E9ge/a/oPt31NMokDSZGA34ApJ84BzgU3L99id4jsAuHi4QGzPtD3N9rT1J+cWUER0penAItuLbS8HLgMOHqLd54AvUvxP/qiqzGr7BDAD+Pvy+Hrg/CqdTyBHUszG29X2c5IeANYeof0awBO2dx7meheshhQRvWg1z2rbDHiw5fgh4I2tDSTtAmxh+2pJx1fpdMSKpyyzLrb9dduHltu5tifqjL2fAIdJmiTppcCbgVuBDYDflknnL4EpZfsfA2+XtE5ZKh4IUFY/90v6W3j+3tFO5Wt+SjE8B0VCi4iYMPrGsLXeEii3GWN5L0lrAGcAHx/L60aseGz3SZoiaa2yzJrovgv8BXAnRVXyifJR3ZcA35d0NzAHuAeenyr+rbL9b4HbWvo6EviapBOBF1GUmHdSjGP+f0mfBL7XmY8VEbH62Z5JcT97OEuBLVqONy/PDVgfeC1wczkV4BXALEkH2Z4zXKfyKDeYJF0EbE9x/+PploDPGPGF8bwtp0zJ0FxEVPLAkiWrNM/5zJ3eVvnfmw/fef2I7yVpTYr1OfemSDi3Ae+2PX+Y9jcDx42UdKDaPZ77ym0NiuwWERE9wPYKScdQTDKbBFxoe76kU4A5tmeNp98qKxf803g6joiIzutbzXOfbM8GZg86d9Iwbfeq0ueoiUfSTQwxi8v2W6u8QUREdM5EnfnVqspQ23Et+2sDhwAr2hNOREQ0XZWhtrmDTv1U0q1tiiciIlZBlRUJ6lZlqG2jlsM1gF0pfhcTERExZlWG2uZS3OMRxRDb/cBR7QwqIiLGpxH3eGxv1YlAIiJi1a3uWW3tUGWo7UUU67S9uTx1M3DuRHsYXEREdIcqQ21fo1gyZuCxAP+rPHd0u4KKiIjxaUTFA7zB9k4txz+UdGe7AoqIiGar8jyePklbDxxIejXdcf8qIqLnjGV16rpUqXiOB26StJhiZtsU4P1tjSoiIhqryqy2GyVtC7ymPLXQ9rPtDSsiIsajq39AKukNwIO2f2P7WUk7UyyXs0TSybYf61iUERFRSTdMLhjpHs+5wHIASW8GvgBcBDzJyA8OioiIGNZIQ22TWqqaw4CZtq8ErpQ0r/2hRUTEWHV7xTOpfPocFE+f+2HLtSqTEiIiIv7ESAnkUuBHkh4BngF+AiBpG4rhtoiImGD6u3lyge1/lnQjsClwnf38p1kDOLYTwUVExNh0w1DbiENmtv9jiHP3ti+ciIho
utyriYhokG6oeKosmRMREbHapOKJiGiQrl65ICIiuk+G2iIiIgZJxRMR0SDd8DueVDwREdFRqXgiIhqkG+7xJPFERDRINySeDLVFRERHJfG0iaQZkuZImvPUsmV1hxMRPaLfrrzVJYmnTWzPtD3N9rT1J0+uO5yIiAkj93giIhok93h6gKTZkl5ZdxwREd0iFc8qsr1/3TFERAzIWm0REdFR/Rlqi4iIbiZpX0kLJS2SdMIQ1z8maYGkuyTdKGnKaH0m8URENEifXXkbjaRJwNnAfsBU4AhJUwc1uwOYZntH4NvAaaP1m8QTERHDmQ4ssr3Y9nLgMuDg1ga2b7L9h/LwP4DNR+s093giIhpkNf8wdDPgwZbjh4A3jtD+KOAHo3WaxBMR0SBj+R2PpBnAjJZTM23PHM/7SnoPMA14y2htk3giInpUmWRGSjRLgS1ajjcvz72ApH2AzwBvsf3saO+bxBMR0SD97l+d3d0GbCtpK4qEczjw7tYGkl4PnAvsa/u3VTrN5IKIiBiS7RXAMcC1wK+Ay23Pl3SKpIPKZqcDk4ErJM2TNGu0flPxREQ0yOr+Aant2cDsQedOatnfZ6x9JvFERDRINyyZk6G2iIjoqFQ8ERENkrXaIiIiBknFExHRIHU+0rqqVDwREdFRqXgiIhpktf58tE2SeCIiGiRDbREREYOk4omIaJBMp46IiBgkFU9ERIN0wz2eJJ6IiAbJUFtERMQgqXgiIhokFU9ERMQgqXgiIhqkf+IXPM2teCTdLGlh+SjWeZK+3XJthqR7yu1WSXu0XDtA0h2S7pS0QNIH6/kEERFj148rb3VpVMUjaS3gRbafLk8daXvOoDYHAB8E9rD9iKRdgKskTQceBWYC020/JOnFwJbl6za0/XinPktERFM1ouKRtL2kLwMLge1Gaf5J4HjbjwDYvh34N+BDwPoUyfjR8tqztheWrztM0i8lfVzSS9vxOSIiVlU3VDxdm3gkrSfp/ZJuAc4DFgA72r6jpdklLUNtp5fndgDmDupuDrCD7ceAWcASSZdKOlLSGgC2vw7sB6wL/FjStyXtO3A9IiKq6eahtoeBu4Cjbd8zTJs/GWobje2jJb0O2Ac4Dngb8L7y2oPA5ySdSpGELqRIWgcN7kfSDGAGwEYbbcT6kyePJYyIiHHpgoULurfiAQ4FlgLfkXSSpCkVX7cA2HXQuV2B+QMHtu+2/RWKpHNIa8PyXtA5wFeBy4FPDfUmtmfanmZ7WpJORMRKXZt4bF9n+zBgT+BJ4HuSbpC05SgvPQ34oqSNASTtTFHRnCNpsqS9WtruDCwp2/2VpLuAU4GbgKm2P2J7PhERE0Q33OPp5qE2AGw/CpwJnFlWI30tly+R9Ey5/4jtfWzPkrQZ8DNJBp4C3mP7YUnrA5+QdC7wDPA05TAbxYSDA20v6cDHiogYly4YaUPuhgHBLrfllCn5kiOikgeWLNGqvH7qlK0q/3uzYMn9q/Re49X1FU9ERKyUtdoiIiIGScUTEdEgE7/eSeKJiGiUbkg8GWqLiIiOSsUTEdEgmVwQERExSCqeiIgGmfj1ThJPRESjdEPiyVBbREQMq3z8y0JJiySdMMT1F0v6Vnn9FxXWy0ziiYhoEo9hG42kScDZFI+BmQocIWnqoGZHAY/b3gb4CvDF0fpN4omIiOFMBxbZXmx7OXAZcPCgNgdTPMUZ4NvA3pJGXAMuiSciokFWZ8UDbAY82HL8UHluyDa2V1A8pmbjkTrN5IIOWNXVZlcHSTNsz6w7jokg38VK+S5Wasp3MZZ/b1qflFya2YnvIBVP75gxepOeke9ipXwXK/Xcd9H6pORyG5x0lgJbtBxvXp4bso2kNYENKJ5fNqwknoiIGM5twLaStpK0FnA4MGtQm1nAe8v9Q4EfepQHvWWoLSIihmR7haRjgGuBScCFtudLOgWYY3sWcAFwsaRFwGMUyWlESTy9o+vHrlejfBcr5btYKd/FEGzPBmYPOndS
y/4fgb8dS5959HVERHRU7vFERERHJfFERERHJfE0mKQ9JL2/3H+ppK3qjikiIomnoSR9Fvgk8Kny1IuAb9YX0cQhaUNJO9YdR0xMkl5RdwxNl8TTXO8ADgKeBrD9X8D6tUZUI0k3S3qJpI2A24HzJJ1Rd1x1kLS7pOsl3StpsaT7JS2uO64J5IK6A2i6TKduruW2LckAktarO6CabWD795KOBi6y/VlJd9UdVE0uAD4KzAX6ao5lwrH9N3XH0HRJPM11uaRzgT+T9AGKpcvPrzmmOq0paVPgXcBn6g6mZk/a/kHdQUTvSuJpKNtfkvQ24PfAdsCJtm+oOaw6nULx6+tbbN8m6dXAr2uOqS43STod+A7w7MBJ27fXF1L0kvyAtGEk3WJ7D0lPUax83rpSbT/Fkhan2z6nlgCjdpJuKncH/vILsO231hRS9JhUPA1je4/yzyEnEkjaGPgZ0BOJR9InbJ8m6SyGeASJ7f9bQ1h1u3mIc/k/0OiYJJ4eY/tRSXvVHUcH/ar8c06tUUwsy1r21wYOYOX3FNF2GWqL6HGSXgxca3uvumOJ3pCKJ3pCeV9jqKG23NeAdSke8BXREUk80SuOa9lfGzgEWFFTLLWSdDcrk/Ak4KUUs/4iOiJDbdGzJN1qe3rdcXSapCkthyuA/7bdk0k46pGKJ3pCuVTOgDWAaRTPhu85tpfUHUP0tiSe6BVzWfm7pueAByhWc4iIDssiodErPgnsbHsr4GKKxVP/UG9IEb0piSd6xYnlIqF7AG+lWLfuazXHFNGTkniiVwyswvw3wHm2rwbWqjGeiJ6VxBO9Ymm5WvdhwOzyR5P57z+iBplOHT1B0rrAvsDdtn9dPiLhdbavqzm0iJ6TxBMRER2VoYaIiOioJJ6IiOio/IA0Gqd85tCN5eErKGa0/a48nm57eRvecxfgZbavGeLaZIrp2ztQ/ID1ceCvbY/5d0SS3gkssH3PKoYcUZsknmgc248COwNIOhlYZvtLVV8vaZLtvtFbvsAuwGuBP0k8wEeB/7R9eNn/n1OsnjAe76R4kmwST3StDLVFT5H0fUlzJc2XdHR5bk1JT0j6F0l3AdMlHSRpYdn2LElXlW0nS/qGpFsl3SHpQEnrACcBR0qaJ+nQQW+7KbB04MD2PbafK/t7b9nXPEnnSFqjJZ4vSLpT0s8lvUzSnsD+wFfK9ltK2lbStWWcP5a0XdnvNyWdKelnkhZLekfLd/BpSXeXff9zeW7IfiLawna2bI3dgJOB41qONyr/XBdYAGxIUfkbeGfLtYeAKRRDY1cAV5XXTgMOL/c3BO6leMzC0cC/DBPDrhRDfT8DPgdsU55/LXAVsGZ5PBN4d0s8+5XnzwBOKPe/Cby9pe+bgK3L/d2B61raXVrGvyNwT3n+QOAnwDqDvo8h+8mWrR1bhtqi13xU0kHl/ubA1sA8YDnw3fL8VGChy1WcJV0K/F157a+A/SSdUB6vDbxqpDe0PVfSq8vX7gPMkTS93H9DeQywDvBg+bJnbP+g3J8L7Dm4X0l/BrwJuLJ8Pbxw+Pwq2wbukrRZeW4f4ELbz5SxPVahn4jVKv9xRc+QtA/wZuBNtp+RdAtF4oDiH/oqP2oTRcVx36C+3zzSi2w/BVxJ8Y+7gP3Kvi60/Y+D+lqTIhEO6GPov6sCHrG98zBv++ygtsMZrZ+I1Sr3eKKXbAA8ViadHSiqjaEsAF4jaYsySRzWcu1a4NiBA0mvL3efAtYfqjNJe5RVBeVSPdsDS4AbgHdJ2qS8trGkEaun1vex/Tjw8MD9m/L+0E6jvP564H+X96WQtNE4+4kYtySe6CVXA+tKWgCcCvxiqEYupjkfQ5EY5gBPAE+Wl/8JWK+8OT+f4h4SwA+BncoJB4MnF2wL/KR85PTtwM+B79m+u+zvhnJSw3XAy0f5DJcCnx6YXAAcDvwfSXcC84EDRnqx7X+nmHk3R9I8ihl3jLWfiFWRJXMihiBpsu1lZcVzLsUab2fV
HVdEE6TiiRja35cVwQKKm/7n1RxPRGOk4omIiI5KxRMRER2VxBMRER2VxBMRER2VxBMRER2VxBMRER2VxBMRER31P/3vskoEvgqAAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa0AAAEYCAYAAADvUanxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xu4XVV97vHvmyByFRFEEWgCGBVQQISgosIR8CBVsEoFilooEFvFR2vFokWKqFXx1BYVlaAcRCkKUjFqMGgEtV4IAcIlgdCIpISDRa4CUkL2fs8fc2xYbPZl7mSvPffc6/08z3wyL2ON9Vvz0f1jjDnmGLJNREREG0xrOoCIiIi6krQiIqI1krQiIqI1krQiIqI1krQiIqI1krQiIqI1krQiImLcSTpH0l2SbhzmuiR9TtIKSddL2qNOvUlaERHRDecCB41w/fXArLLNAb5Up9IkrYiIGHe2fwbcO0KRQ4HzXPk18ExJW49Wb5JWREQ0YRvg9o7jVeXciNbrWjgxPpZenHm2gJkHv7/pECaNhXtv1XQIk8bMv53TdAiTyvRXHK91qmAMf2/04sPeSdWtN2Cu7bnr9P01JGlFRMSYlQS1LknqDmC7juNty7kRpXswIiIAcF9f7W0czAPeUUYRvhx4wPado30oLa2IiBh3ki4A9gO2lLQK+EfgaQC2vwzMBw4GVgB/BI6pU2+SVkREVPrWjFtVto8c5bqBd4+13iStiIgAwP31k9a6jfhYe3mmFRERrZGWVkREVMZngEVXpaUVERGtkZZWREQA4HEciNEtaWlFRERrpKUVERGVFrS0krQiIgIY25D3pqR7MCIiWiNJawiSrpC0XNISSTdJmtNxbTNJ55XVNn9T9jcr16aVlThvlHSDpKskbd/cL4mIGIO+vvpbQ5K0CknrS9q449RRtncH9gE+LWn9cv6rwK22n297R+C3wFfKtcOB5wG72n4J8GfA/aX+zSfid0RETGU9/0xL0k7AccCby3btoCKbAA8DfZKeD7yMKjkNOA1YIWlHYGvgTtv9ALZXdZS7RNIDVAluvu3J33kcET0lQ94nKUkbSzpG0n8AZwPLqFpHnQnrfEnXA8uBj9nuA3YGlpR9AMr+EmAX4ELgjaVb8Z8lvbSjvv2AzwKHATdJ+qeSBCMiJoe+NfW3hvRk0gLuBI4FjrP9Kttftf3goDJH2d4V+BPgA5JmjFZpaVm9EPgQ0A8slLR/uWbbV9h+B1VrzcDNkt4yuB5JcyQtlrR47kU/WpffGRExpfRq9+BhVEnr3yV9E/ia7ZVDFbT9e0nXAHsD1wC7S5o20AUoaRqwO1VrDduPApcCl0r6b+BNwMJSdkOq51x/BTwTeC/wlKz0pBVBx7D8dUTEunB/5h6clGxfZvtw4NXAA8B3Jf1Y0szBZSVtBLwU+I3tFVTPvE7uKHIycI3tFZL2kPS88rlpwK7AynJ8OlVieyVwou09bZ9p+w/d+p0REVNNr7a0ALB9D3AGcIak2UDnf2acL+kR4OnAubavLuePBT4v6Tfl+FflHMBWwNmSnl6OFwFfKPtXAKfY/p+u/JiIiHXUhoEYPZ20Otle1LG/3wjl7gPeNsy1HwI/HOba/HUMMSKiu1qQtHqyezAiItopLa2IiAAyECMiImJcpaUVERGVPNOKiIgYP2lpRUQEkCHvERHRJi1IWukejIiI1khLKyIigAx5j4iIGFdpaUVERKUFz7SStCIiAgD3pXswIiJi3KSlNcnNPPj9TYcwKdw2/7NNhzBpHHHYJ5sOYdL49REfbzqESeW2lcev0+fb8J5WWloREdEaaWlFRESlf/K3tJK0IiICyECMiIjoYZIOkrRc0gpJJw1x/U8kXS7pWknXSzp4tDrT0oqIiMo4trQkTQfOBA4EVgFXSZpne1lHsZOBC21/SdLOwHxg5kj1pqUVERHdMBtYYftW26uBbwKH
Dipj4BllfzPg/41WaVpaEREBjG3Iu6Q5wJyOU3Ntz+043ga4veN4FbD3oGpOBS6T9B5gY+CA0b43SSsiIsasJKi5oxYc2ZHAubb/WdIrgK9LerHt/uE+kKQVERGV8R09eAewXcfxtuVcp2OBgwBs/0rSBsCWwF3DVZpnWhERAVRD3utuNVwFzJK0vaT1gSOAeYPK/BewP4CknYANgN+PVGmSVkREjDvba4ATgAXATVSjBJdKOk3SIaXY3wHHS7oOuAA42rZHqjfdgxERAYz/IpC251MNY+88d0rH/jJgn7HUmZZWRES0RlpaERFRacE0TklaEREBZO7BiIiIcZWktQ7KKJj3dRx/QtJ7JX1G0o2SbpB0eLm2n6Tvd5T9gqSjGwg7ImJI7uuvvTUlSWvdnAO8A0DSNKr3EFYBuwO7UU1J8hlJW4+lUklzJC2WtPjBhx4a55AjItorz7TWge3bJN0j6aXAc4BrgVcBF9juA/5b0k+BvYA/jKHex6dHmTljxojvLEREjJsGW1B1JWmtu68ARwPPpWp5HThMuTU8uWW7QXfDiogYmwzE6A3foZo7ay+qN79/DhwuabqkZwOvARYBK4GdJT1d0jMpU5dERER9aWmtI9urJV0O3G+7T9J3gFcA11GtFfNB278DkHQhcCPwW6quxIiIScN9k/9pRJLWOioDMF4O/DlAmTfrxLI9ie0PAh+c0AAjIqaQJK11UJaH/j7wHdv/2XQ8ERHrosmh7HUlaa2DMtnjDk3HERHRK5K0IiICSEsrIiJaxP2TfyBGhrxHRERrpKUVERFAO4a8p6UVERGtkZZWREQA4Mk/i9PoLS1JG0n6iKSzy/EsSW/ofmgRETGR3OfaW1PqdA/+X+BRqqmJAO4APt61iCIiIoZRJ2ntaPt04DEA238E1NWoIiJiwvX319+aUidprZa0IdXkr0jakarlFRERMaHqDMT4R+CHwHaSzgf2oVo/KmLCzDz4/U2HMGlMbzqAmLLaMBBj1KRl+0eSrqGayVzAe23f3fXIIiJiQrUhadUZPfhnwBrbP7D9fWCNpDd1P7SIiIgnq/NM6x9tPzBwYPt+qi7DiIiYQqbKQIyhyuSl5IiImHB1ks9iSZ8FzizH7wau7l5IERHRhCnxTAt4D7Aa+FbZHqVKXBEREROqzujBh4GTJiCWiIhoUH//5J83YtSkJekFwAeAmZ3lbb+2e2FFRMREa3KARV11nmldBHwZ+ArQgh7PiIiYquokrTW2v9T1SCIiolFTZSDG9yS9S9LWkp41sHU9soiIaDVJB0laLmmFpCHHRkh6q6RlkpZK+rfR6qzT0vrL8u+JHecM7FDjsxER0RLjORBD0nSqV6UOBFYBV0maZ3tZR5lZwIeAfWzfJ2mr0eqtM3pw+7UPOyIi2qJ/fLsHZwMrbN8KIOmbwKHAso4yxwNn2r4PwPZdo1Vad+XikyXNLcdZuTgiIkazDXB7x/Gqcq7TC4AXSPqFpF9LOmi0SuuuXLwaeGU5zsrFo5B0W9MxRESMVX+/am+S5kha3LHNWYuvXA+YBewHHAmcLemZo31gNDvaPlzSkVCtXCxp8r+BNslIWs/2mqbjiIgYD7bnAnNHKHIHsF3H8bblXKdVwJW2HwN+K+kWqiR21XCVZuXi7vg9gKT9JP1c0jxKP66kt0laJGmJpLPKw8qIiMa5X7W3Gq4CZknaXtL6wBHAvEFlLqFqZSFpS6ruwltHqrRO0jqVJ69cvBD4+zoR9yrbe3Uc7kG1cOYLJO0EHE41UmZ3qpe1jxr8+c5m94MPPTQxQUdEzxvPpUlKz9IJwALgJuBC20slnSbpkFJsAXCPpGXA5cCJtu8Zqd46owcvk3Q1Wbl4bS2y/duyvz/wMqqhnwAbAk8ZLdPZ7J45Y4YnKM6IiHFlez4wf9C5Uzr2Dby/bLXUmXtwoe39gR8McS5G93DHvoCv2f5QU8FERAynDRPmDts9KGmDMvPFlpI275gNYyZPHbYY
9SwEDht4ga7czxkNxxQR0RojtbTeCbwPeB7Voo8DKfgPwBe6HNeUZHuZpJOByyRNAx6jWptsZbORRUS0o6U1bNKyfQZwhqT32P78BMY0Zdi+Arhi0LmBxTQjIiaVvjYnrQG2Py/plTx1Pa3zuhhXRETEU9QZiPF1YEdgCU+sp2UgSSsiYgppdfdghz2BncvQxIiIiMbUSVo3As8F7uxyLBER0aB+T42W1pbAMkmL6Ji+yfYhw38kIiJi/NVJWqd2O4iIiGhenemZmlZn9OBPywuws2z/WNJGQCZ5jYiYYvpa0D1YZxHI44FvA2eVU9tQzcwbERExoep0D76batnkKwFs/+fANEQRETF1tGHIe52lSR61vXrgQNJ6lLW1IiIiJlKdltZPJX0Y2FDSgcC7gO91N6yIiJhoU+KZFnAS1Uq8N1BNojsfOLmbQUVExMTrt2pvTakzerAfOFvS14BdgDsyO0ZEczbK4N3HPfj4zHLRK0ZaT+vLknYp+5tRzT14HnCtpCMnKL6IiJggfVbtrSkjdQ++2vbSsn8McIvtl1AtF//BrkcWERExyEjdg6s79g8ELgKw/Ttp8j+si4iIselrwYOfkZLW/ZLeANwB7AMcC48Ped9wAmKLiIgJ1PYJc98JfI5qhvf32f5dOb8/8INuBxYRETHYsEnL9i3AQUOcXwAs6GZQEREx8abKe1oRERGTQp0ZMSIioge0YSBGWloREdEadZYmeY6kr0q6tBzvLOnY7ocWERETqQ/V3ppSp6V1LtXAi+eV41uA93UroIiIaEaf629NqZO0trR9IdAPYHsNZMKviIiYeHUGYjwsaQvKGlqSXg480NWoIiJiwrWhNVInab0fmAfsKOkXwLOBw7oaVURExBBG7R60fQ2wL/BKqlkydrF9fbcDGy+SZkq6scvf8eG1+MzRkr7QjXgiItZG3xi2ptQZPfhuYBPbS23fCGwi6V3dD61Vxpy0IiImm6kyevB42/cPHNi+Dzi+eyHVI+ltkhZJWiLpLEkzJP2npC0lTZP0c0mvK8WnSzpb0lJJl0nasNRxvKSrJF0n6WJJG5Xz50o6rOO7Hir/bi3pZ+U7b5T0akmfAjYs584fJrbp5fwxkm6RtIhqEuKIiBiDOklrujrWIil/gNfvXkijk7QTcDiwj+3dqVqr+wKfBr4E/B2wzPZl5SOzgDNt7wLcD7ylnP9323vZ3g24iTKT/Qj+AlhQvnM3YIntk4BHbO9u+6hhYjtK0tbAR6mS1auAndf9TkREjJ8+u/bWlDpJawHwLUn7S9ofuAD4YXfDGtX+VItRXiVpSTnewfZXgGcAfw18oKP8b20vKftXAzPL/otLi+wG4Chgl1G+9yrgGEmnAi+x/WDd2IC9gSts/972auBbw32JpDmSFkta/OBDD40SUkTE5CTpIEnLJa2QdNII5d4iyZL2HK3OOqMHPwjMAf6mHP8I+EqtiLtHwNdsf+hJJ6vuvW3L4SbAQFJ5tKNYH0+sB3Yu8Cbb10k6GtivnF9DSeiSplFalrZ/Juk1wJ8C50r6rO3zasb2pro/zvZcYC7AzBkzWjAbWERMBeM5wKL0yp1JtYjwKqr/kJ9ne9mgcpsC7wWurFPviC2t8qVft/1l24eV7SzbTQ/nXwgcJmkrAEnPkjSDqnvwfOAU4Owa9WwK3CnpaVQtrQG3UbWWAA4Bnla+Zwbw37bPpkrce5Qyj5U6RortSmBfSVuUsn8+9p8dEdE94zx6cDawwvatpXfpm8ChQ5T7GNXf7v+pU+mISaskpxmSGn2GNVjJ1CcDl0m6nqr1NxPYC/i07fOB1ZKOGaWqj1Alk18AN3ecP5sqwVwHvAJ4uJzfD7hO0rVUz63OKOfnAtdLOn+Y2La2fSdwKvCr8n03rd2vj4hoXudjjLLNGVRkG+D2juNV5VxnHXsA29muvbCwPMoDNUnnATtRvWA88Mcb25+t+yWx9tI9GINtyvSmQ5g0HmzFHA4T57aVK9dp
LPoZux1Y++/Ne6/70YjfVUZgH2T7uHL8dmBv2yeU42nAT4Cjbd8m6QrgA7YXj1RvnWdavynbNKrutIiIiNHcAWzXcbxtOTdgU+DFwBVlgPpzgXmSDhkpcY2atGx/dK3CjYiIVuljXDt2rgJmSdqeKlkdQfXaEAC2HwC2HDget5aWpMvhqb/E9mvrRh4REb3F9hpJJ1C9NjUdOMf2UkmnAYttz1ubeut0D3a+77QB1Yu5a9bmyyIiYvIa7yeEtucD8wedO2WYsvvVqbNO9+DVg079okxDFBERU0iTM13UVad78Fkdh9Oo3l/arGsRRUREDKNO9+DVVM+0RNUt+FtGn6MvIiJapg0vENTpHtx+IgKJiIgYTZ3uwadRzTv4mnLqCuAs2491Ma6IiJhg4zzkvSvqdA9+iWruvS+W47eXc8d1K6iIiJh4UyVpDaw3NeAnZU6+iIiICVVnPa0+STsOHEjagXY8r4uIiDEY51neu6JOS+tE4HJJt1KNIJwBjDZ7ekR0SSaJjV5WZ/TgQkmzgBeWU8ttPzrSZyIion1a/XKxpL2A223/zvajknanmsJppaRTbd87YVFGRETXtWEgxkjPtM4CVgOUJeY/BZwHPEBZCj4iImIijdQ9OL2jNXU4MNf2xcDFkpZ0P7SIiJhIbW9pTZc0kNT2p1phckCdARwRERHjaqTkcwHwU0l3A48APweQ9HyqLsKIiJhC+ts8EMP2JyQtBLYGLrMf/zXTgPdMRHARERGdRuzms/3rIc7d0r1wIiKiKW14ppVnUxERAbQjadWZxikiImJSSEsrIiKAdsyIkZZWRES0RlpaEREBtOOZVpJWREQA7XhPK92DERHRGmlpRUQE0I7uwbS0IiKiNdLSiogIIC2tWEuS5khaLGnxgw891HQ4EdEj+u3aW1OStCYh23Nt72l7z0032aTpcCIiJo10D0ZEBJDuwRiFpPmSntd0HBERbZGWVoNsH9x0DBERAzL3YERExDhKSysiIgDozzOtiIhoiz679laHpIMkLZe0QtJJQ1x/v6Rlkq6XtFDSjNHqTNKKiIhxJ2k6cCbwemBn4EhJOw8qdi2wp+1dgW8Dp49Wb5JWREQA4/5y8Wxghe1bba8Gvgkc2lnA9uW2/1gOfw1sO1qlSVoREdEN2wC3dxyvKueGcyxw6WiVZiBGREQAY3u5WNIcYE7Hqbm2567N90p6G7AnsO9oZZO0IiICgH731y5bEtRISeoOYLuO423LuSeRdADwD8C+th8d7XvTPRgREd1wFTBL0vaS1geOAOZ1FpD0UuAs4BDbd9WpNC2tiIgAxvc9LdtrJJ0ALACmA+fYXirpNGCx7XnAZ4BNgIskAfyX7UNGqjdJKyIiusL2fGD+oHOndOwfMNY6k7QiIgJox9yDSVoREQFkGqeIiIhxlZZWREQA1J3polFpaUVERGukpRUREQDUf7W4OWlpRUREa6SlFRERQDueaSVpRUQEkCHvERER4yotrYiIANrRPZiWVkREtEZaWhERAbTjmVaSVkREAO1IWukejIiI1khLKyIiAOif/A2ttLQiIqI9krSGIOkKScslLSnbtzuuzZF0c9kWSXpVx7U3SLpW0nWSlkl6ZzO/ICJi7Ppx7a0p6R4sJK0PPM32w+XUUbYXDyrzBuCdwKts3y1pD+ASSbOBe4C5wGzbqyQ9HZhZPre57fsm6rdERKyNDMRoAUk7SfpnYDnwglGK/z1wou27AWxfA3wNeDewKdV/BNxTrj1qe3n53OGSbpT0d5Ke3Y3fERHRC3oyaUnaWNIxkv4DOBtYBuxq+9qOYud3dA9+ppzbBbh6UHWLgV1s3wvMA1ZKukDSUZKmAdj+MvB6YCPgZ5K+LemggetDxDdH0mJJix986KFx+90RESOx629N6dXuwTuB64HjbN88TJmndA+OxvZxkl4CHAB8ADgQOLpcux34mKSPUyWwc6gS3iFD1DOXqquRmTNmTP72ekTEBOnJlhZwGHAH
8O+STpE0o+bnlgEvG3TuZcDSgQPbN9j+F6qE9ZbOguXZ1xeBzwEXAh9au/AjIsZfGwZi9GTSsn2Z7cOBVwMPAN+V9GNJM0f56OnApyVtASBpd6qW1BclbSJpv46yuwMrS7nXSboe+DhwObCz7ffZXkpERNTWq92DANi+BzgDOKO0gvo6Lp8v6ZGyf7ftA2zPk7QN8EtJBh4E3mb7TkmbAh+UdBbwCPAwpWuQanDGG22vnICfFRGxVtrwLEJuwVT0vSzPtCKirttWrtS6fP5FM2bW/ntz88rb1um71lZPdg9GREQ79XT3YEREPKEN3TppaUVERGukpRUREUA7WlpJWhERAWTuwYiIiHGVllZERADt6B5MSysiIlojLa2IiADS0oqIiBbxGLY6yhJMyyWtkHTSENefLulb5fqVNeZ/TdKKiIjxJ2k6cCbVUkw7A0dK2nlQsWOB+2w/H/gX4NOj1ZukFRERwLi3tGYDK2zfans18E3g0EFlDqVa/R3g28D+kkac0zDPtCa5dZ0AczxImlMWpux5uRdPyL14wlS5F2P5eyNpDjCn49TcQfdgG+D2juNVwN6Dqnm8jO01kh4AtgDuHu5709KKOuaMXqRn5F48IffiCT13L2zPtb1nxzYhSTtJKyIiuuEOYLuO423LuSHLSFoP2Ixq/cFhJWlFREQ3XAXMkrS9pPWBI4B5g8rMA/6y7B8G/MSjLPKYZ1pRR+v76sdR7sUTci+ekHsxSHlGdQKwAJgOnGN7qaTTgMW25wFfBb4uaQVwL1ViG1FWLo6IiNZI92BERLRGklZERLRGklZERLRGklZERLRGklYMSdLpkp4h6WmSFkr6vaS3NR1XEyS9t9wLSfqqpGskva7puJqQezEySc9tOoapLkkrhvM6238A3gDcBjwfOLHRiJrzV+VevA7YHHg78KlmQ2pM7sXIvtp0AFNdklYMZ+Advj8FLrL9QJPBNGxgPraDga/bXtpxrtfkXozA9p82HcNUl6QVw/m+pJuBlwELJT0b+J+GY2rK1ZIuo/pDvUDSpkB/wzE1JfciGpWXi2NYkp4FPGC7T9JGwDNs/67puCaapGnA7sCttu+XtAWwje3rGw5twuVeRNMyjVMMSdI7OvY7L5038dE0zlSL2L0BOA3YGNig0YiaY2AX4CDgn4BN6N17EQ1ISyuGJOnzHYcbAPsD19g+rKGQGiPpS1RdYK+1vZOkzYHLbO/VcGgTrtwLAXvYnl1a4wt68V5EM9LSiiHZfk/nsaRnUq082ov2tr2HpGsBbN9XZq3uGZIOtf1d4OW2XyrpcgDb9/bavYhmZSBG1PUwsH3TQTTkMUnTKauMl0EpPTP4QNIhwG7lcHUv34toXlpaMSRJ36P8YaJaVmAn4MLmImrU54DvAFtJ+gTVuj8faTakCXVTWUYCqntxCbCdpH8C3gKc3Fhk0XPyTCuGJGnfjsM1wErbq5qKp2mSXkT1XE/AQts3NRxSY3IvoklJWjEsSc8BBh6wL7J9V5PxNEXS122/fbRzU13pFlxq+0VNxxK9K8+0YkiS3gosAv4ceCtwpaSeGzlY7NJ5IGk9qpeue4rtPmC5pD9pOpboXXmmFcP5B2CvgdZVeeD+Y+DbjUY1gSR9CPgwsKGkPwycBlbTu8urbw4slbSIanAOALYPaS6k6CVJWjGcaYO6A++hx1rmtj8JfFLS6cANwA62P1paGr06m/cGVC9ZDxDw6YZiiR6UpBXDuVTSAuCCcnw4ML/BeCacpH1t/xR4BvBy4LXAR4EHgYt54nlfL1mv3JPHSdqwqWCi9yRpxXDuAr5BNc8cwFzb32kwngkl6c3AjsBPgdl5uVh/A7wL2EFS5zyDmwK/aCaq6EVJWjGcjYGTgHuBbwG/bDacCfc74H+V/Z5+ubj4N+BS4JNU/7sY8KDte5sJKXpRhrzHiCTtStU1+BZgle0DGg5pwkiaafs2SUdR3YM9gK9RvVx8su2LGg0wogelpRWjuYuq
1XEPsFXDsUwo27eVf8+XdDVPvFD7prxQG9GMtLRiSJLeRfV+1rOBi4ALbS9rNqqI6HVpacVwtgPeZ3tJ04FERAxISysiIlqjp14WjYiIdkvSioiI1sgzrYgOkrYAFpbD5wJ9wO/L8Wzbq7vwnXsAW9n+4RDXNgG+QjVpr4D7gP9t+49r8T1vBpbZvnkdQ45oTJJWRAfb91BmAZF0KvCQ7f9T9/OSppfZ0MdiD+DFwFOSFvC3wH/ZPqLU/yLgsTHWP+DNVC9FJ2lFa6V7MKImSd+TdLWkpZKOK+fWk3S/pH8t0xvNlnSIpOWl7OclXVLKbiLpXEmLJF0r6Y1l3r5TgKMkLRli+ZetgTsGDmzfbPuxUt9flrqWSPqipGkd8XxK0nWSfiVpK0mvBg4G/qWUnylplqQFJc6fSXpBqfcbks6Q9EtJt0r6s4578GFJN5S6P1HODVlPRFfYzpYt2xAbcCrwgY7jZ5V/NwKWUS3TsR7V9E5v7ri2CphB1Z13EXBJuXY6cETZ3xy4hWrW9OOAfx0mhpdRdU/+EvgY8Pxy/sVUy96vV47nAn/REc/ry/nPAieV/W9QvRg9UPflwI5lfx/gso5yF5T4dwVuLuffCPwc2HDQ/RiynmzZurGlezCivr+VNLBu1LZUE+ouoVpfa2Ay4Z2B5bZXAki6AHhHufY64PWSBubu2wAYcUFF21dL2qF89gBgsaTZZX+vcgywIXB7+dgjti8t+1cDrx5cr6RnUs1cf3H5PDz5ccEltg1cL2mbcu4A4Bzbj5TY7q1RT8S4yv+4ImqQdADwGuDlth+R9B9USQeqJFHnhceBKaB+M6ju14z0IdsDS6FcrCozvL7UdY7tjwyqaz2qJDqgj6H/fy7gbtu7D3EN4NFBZYczWj0R4yrPtCLq2Qy4tySsXRh+La1lwAslbVcSzOEd1xYA7xk4kPTSsvsg1RIfTyHpVaU1g6SnAzsBK6lWkX6rpC3LtS3K4pQjefx7bN8H3DnwvKo8D9ttlM//CPirgfWzJD1rLeuJWGtJWhH1/ADYSNIy4OPAlUMVcjUU/QSqpLIYuB94oFz+KLBxGciwlOqZGcBPgN3K4IzBAzFmAT+XdANwDfAr4Lu2byj1/bgMALkMeM4ov+HRKc60AAAAgklEQVQC4MMDAzGAI4C/lnQdsJQnr0g81G/7PtUIx8WSllCNbGSs9USsi0zjFDHOJG1i+6HS0joLuMH255uOK2IqSEsrYvz9TWmJLKMaIHF2w/FETBlpaUVERGukpRUREa2RpBUREa2RpBUREa2RpBUREa2RpBUREa2RpBUREa3x/wFz5uexnkUcUwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "for sample in top_results:\n",
    "    plt.figure()\n",
    "    target_len = len(sample['sampled'])\n",
    "    source_len = len(sample['source'])\n",
    "\n",
    "    attention_matrix = sample['attention'][:target_len, :source_len+2].transpose()#[::-1]\n",
    "    ax = sns.heatmap(attention_matrix, center=0.0)\n",
    "    ylabs = [\"<BOS>\"]+sample['source']+[\"<EOS>\"]\n",
    "    #ylabs = sample['source']\n",
    "    #ylabs = ylabs[::-1]\n",
    "    ax.set_yticklabels(ylabs, rotation=0)\n",
    "    ax.set_xticklabels(sample['sampled'], rotation=90)\n",
    "    ax.set_xlabel(\"Target Sentence\")\n",
    "    ax.set_ylabel(\"Source Sentence\\n\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'source': \"you 're not as good as you think you are .\",\n",
       " 'truth': \"vous n'êtes pas aussi bonnes que vous pensez l'être .\",\n",
       " 'sampled': \"vous n'êtes pas très pas vous vous vous vous ?\"}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_source_sentence(vectorizer, batch_dict, index):\n",
    "    indices = batch_dict['x_source'][index].cpu().data.numpy()\n",
    "    vocab = vectorizer.source_vocab\n",
    "    return sentence_from_indices(indices, vocab)\n",
    "\n",
    "def get_true_sentence(vectorizer, batch_dict, index):\n",
    "    return sentence_from_indices(batch_dict['y_target'].cpu().data.numpy()[index], vectorizer.target_vocab)\n",
    "    \n",
    "def get_sampled_sentence(vectorizer, batch_dict, index):\n",
    "    y_pred = model(x_source=batch_dict['x_source'], \n",
    "                   x_source_lengths=batch_dict['x_source_length'], \n",
    "                   target_sequence=batch_dict['x_target'], \n",
    "                   sample_probability=1.0)\n",
    "    return sentence_from_indices(torch.max(y_pred, dim=2)[1].cpu().data.numpy()[index], vectorizer.target_vocab)\n",
    "\n",
    "def get_all_sentences(vectorizer, batch_dict, index):\n",
    "    return {\"source\": get_source_sentence(vectorizer, batch_dict, index), \n",
    "            \"truth\": get_true_sentence(vectorizer, batch_dict, index), \n",
    "            \"sampled\": get_sampled_sentence(vectorizer, batch_dict, index)}\n",
    "    \n",
    "def sentence_from_indices(indices, vocab, strict=True):\n",
    "    ignore_indices = set([vocab.mask_index, vocab.begin_seq_index, vocab.end_seq_index])\n",
    "    out = []\n",
    "    for index in indices:\n",
    "        if index == vocab.begin_seq_index and strict:\n",
    "            continue\n",
    "        elif index == vocab.end_seq_index and strict:\n",
    "            return \" \".join(out)\n",
    "        else:\n",
    "            out.append(vocab.lookup_index(index))\n",
    "    return \" \".join(out)\n",
    "\n",
    "results = get_all_sentences(vectorizer, batch_dict, 1)\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlpbook",
   "language": "python",
   "name": "nlpbook"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
